diff --git a/src/MIDebugEngine/AD7.Impl/AD7Engine.cs b/src/MIDebugEngine/AD7.Impl/AD7Engine.cs index b9e60a869..ffb8146b8 100755 --- a/src/MIDebugEngine/AD7.Impl/AD7Engine.cs +++ b/src/MIDebugEngine/AD7.Impl/AD7Engine.cs @@ -229,6 +229,7 @@ public int Attach(IDebugProgram2[] portProgramArray, IDebugProgramNode2[] progra if (port is IDebugUnixShellPort) { _unixPort = (IDebugUnixShellPort)port; + (_unixPort as IDebugPortCleanup)?.AddSessionRef(); } StartDebugging(launchOptions); } diff --git a/src/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier.cs b/src/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier.cs index 87f9da2d6..7ee30a393 100644 --- a/src/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier.cs +++ b/src/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier/Microsoft.VisualStudio.Debugger.Interop.UnixPortSupplier.cs @@ -73,7 +73,7 @@ public interface IDebugUnixShellPort } /// - /// Interface implemented by a port that supports explicit cleanup + /// Interface implemented by a port that supports explicit cleanup of shared connections. /// [ComImport()] [ComVisible(true)] @@ -82,9 +82,14 @@ public interface IDebugUnixShellPort public interface IDebugPortCleanup { /// - /// Clean up debugging resources + /// Decrement the session reference count and close the connection when it reaches zero. /// void Clean(); + + /// + /// Increment the session reference count on this port. + /// + void AddSessionRef(); } /// diff --git a/src/SSHDebugPS/AD7/AD7Port.cs b/src/SSHDebugPS/AD7/AD7Port.cs index 819f63eda..951887bbd 100644 --- a/src/SSHDebugPS/AD7/AD7Port.cs +++ b/src/SSHDebugPS/AD7/AD7Port.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; using System.Linq; using System.Threading; @@ -21,6 +22,7 @@ internal abstract class AD7Port : IDebugPort2, IDebugUnixShellPort, IDebugPortCl private Connection _connection; private readonly Dictionary _eventCallbacks = new Dictionary(); private uint _lastCallbackCookie; + private int _refCount; protected string Name { get; private set; } @@ -37,16 +39,19 @@ public AD7Port(AD7PortSupplier portSupplier, string name, bool isInAddPort) protected Connection GetConnection() { - if (_connection == null) + lock (_lock) { - _connection = GetConnectionInternal(); - if (_connection != null) + if (_connection == null) { - Name = _connection.Name; + _connection = GetConnectionInternal(); + if (_connection != null) + { + Name = _connection.Name; + } } - } - return _connection; + return _connection; + } } protected abstract Connection GetConnectionInternal(); @@ -60,7 +65,10 @@ public bool IsConnected { get { - return _connection != null; + lock (_lock) + { + return _connection != null; + } } } @@ -88,6 +96,8 @@ private AD7Process[] EnumProcessesInternal() result = processList.Select((proc) => new AD7Process(this, proc)).ToArray(); }); + CloseConnectionIfIdle(); + return result; } @@ -155,7 +165,115 @@ void IDebugUnixShellPort.ExecuteSyncCommand(string commandDescription, string co void IDebugUnixShellPort.BeginExecuteAsyncCommand(string commandText, bool runInShell, IDebugUnixShellCommandCallback callback, out IDebugUnixShellAsyncCommand asyncCommand) { - GetConnection().BeginExecuteAsyncCommand(commandText, runInShell, callback, out asyncCommand); + var wrappedCallback = new AsyncCommandCallback(this, callback); + lock (_lock) + { + _refCount++; + } + try + { + var connection = GetConnection(); + connection.BeginExecuteAsyncCommand(commandText, runInShell, wrappedCallback, out asyncCommand); + asyncCommand = new AsyncCommandWrapper(wrappedCallback, asyncCommand); + } + catch + { + lock (_lock) + { + _refCount--; + } + CloseConnectionIfIdle(); + throw; + } + } + + /// + /// Wraps IDebugUnixShellAsyncCommand so that Abort() triggers NotifyExited on the + /// callback, ensuring the ref count is decremented even when OnExit does not fire. + /// + private class AsyncCommandWrapper : IDebugUnixShellAsyncCommand + { + private readonly AsyncCommandCallback _callback; + private readonly IDebugUnixShellAsyncCommand _inner; + + public AsyncCommandWrapper(AsyncCommandCallback callback, IDebugUnixShellAsyncCommand inner) + { + _callback = callback; + _inner = inner; + } + + public void Write(string text) => _inner.Write(text); + public void WriteLine(string text) => _inner.WriteLine(text); + + public void Abort() + { + _inner.Abort(); + _callback.NotifyExited(); + } + } + + private void OnAsyncCommandExited() + { + lock (_lock) + { + Debug.Assert(_refCount > 0, "Underflowing _refCount"); + _refCount--; + } + CloseConnectionIfIdle(); + } + + private void CloseConnectionIfIdle() + { + lock (_lock) + { + if (_refCount > 0 || _connection == null) + { + return; + } + + var conn = _connection; + _connection = null; + try { + conn.Close(); + } + // Dev15 632648: Liblinux sometimes throws exceptions on shutdown - we are shutting down anyways, so ignore to not crash + catch (Exception) { } + } + } + + /// + /// Wraps an IDebugUnixShellCommandCallback to track async command exits. + /// + private class AsyncCommandCallback : IDebugUnixShellCommandCallback + { + private readonly AD7Port _port; + private readonly IDebugUnixShellCommandCallback _inner; + private int _notified; + + public AsyncCommandCallback(AD7Port port, IDebugUnixShellCommandCallback inner) + { + _port = port; + _inner = inner; + } + + public void OnOutputLine(string line) + { + _inner.OnOutputLine(line); + } + + public void OnExit(string exitCode) + { + _inner.OnExit(exitCode); + NotifyExited(); + } + + public void NotifyExited() + { + if (Interlocked.CompareExchange(ref _notified, 1, 0) == 0) + { + _port.OnAsyncCommandExited(); + } + } } void IConnectionPointContainer.EnumConnectionPoints(out IEnumConnectionPoints ppEnum) @@ -241,14 +359,23 @@ public bool IsLinux() return GetConnection().IsLinux(); } + public void AddSessionRef() + { + lock (_lock) + { + _refCount++; + } + EnsureConnected(); + } + public void Clean() { - try + lock (_lock) { - _connection?.Close(); + Debug.Assert(_refCount > 0, "Underflowing _refCount"); + _refCount--; } - // Dev15 632648: Liblinux sometimes throws exceptions on shutdown - we are shutting down anyways, so ignore to not crash - catch (Exception) { } + CloseConnectionIfIdle(); } } } diff --git a/src/SSHDebugPS/SSH/SSHHelper.cs b/src/SSHDebugPS/SSH/SSHHelper.cs index de1e0d3ea..f506d842b 100644 --- a/src/SSHDebugPS/SSH/SSHHelper.cs +++ b/src/SSHDebugPS/SSH/SSHHelper.cs @@ -23,53 +23,61 @@ internal static SSHConnection CreateSSHConnectionFromConnectionInfo(ConnectionIn if (connectionInfo != null) { UnixSystem remoteSystem = new UnixSystem(); - string name = SSHPortSupplier.GetFormattedSSHConnectionName(connectionInfo); - - while (true) + bool success = false; + try { - try - { - VSOperationWaiter.Wait( - StringResources.WaitingOp_Connecting.FormatCurrentCultureWithArgs(name), - throwOnCancel: false, - action: (cancellationToken) => - remoteSystem.Connect(connectionInfo)); - break; - } - catch (RemoteAuthenticationException) + string name = SSHPortSupplier.GetFormattedSSHConnectionName(connectionInfo); + + while (true) { - IVsConnectionManager connectionManager = (IVsConnectionManager)ServiceProvider.GlobalProvider.GetService(typeof(IVsConnectionManager)); - if (connectionManager != null) + try { - IConnectionManagerResult result = connectionManager.ShowDialog(StringResources.AuthenticationFailureHeader, StringResources.AuthenticationFailureDescription, connectionInfo); - - if (result != null && (result.DialogResult & ConnectionManagerDialogResult.Succeeded) == ConnectionManagerDialogResult.Succeeded) + VSOperationWaiter.Wait( + StringResources.WaitingOp_Connecting.FormatCurrentCultureWithArgs(name), + throwOnCancel: false, + action: (cancellationToken) => + remoteSystem.Connect(connectionInfo)); + break; + } + catch (RemoteAuthenticationException) + { + IVsConnectionManager connectionManager = (IVsConnectionManager)ServiceProvider.GlobalProvider.GetService(typeof(IVsConnectionManager)); + if (connectionManager != null) { - connectionInfo = result.ConnectionInfo; + IConnectionManagerResult result = connectionManager.ShowDialog(StringResources.AuthenticationFailureHeader, StringResources.AuthenticationFailureDescription, connectionInfo); + + if (result != null && (result.DialogResult & ConnectionManagerDialogResult.Succeeded) == ConnectionManagerDialogResult.Succeeded) + { + connectionInfo = result.ConnectionInfo; + } + else + { + return null; + } } else { - return null; + throw new InvalidOperationException("Why is IVsConnectionManager null?"); } } - else + catch (Exception ex) { - throw new InvalidOperationException("Why is IVsConnectionManager null?"); + VsShellUtilities.ShowMessageBox(ServiceProvider.GlobalProvider, ex.Message, null, + OLEMSGICON.OLEMSGICON_CRITICAL, OLEMSGBUTTON.OLEMSGBUTTON_OK, OLEMSGDEFBUTTON.OLEMSGDEFBUTTON_FIRST); + return null; } } - catch (Exception ex) - { - VsShellUtilities.ShowMessageBox(ServiceProvider.GlobalProvider, ex.Message, null, - OLEMSGICON.OLEMSGICON_CRITICAL, OLEMSGBUTTON.OLEMSGBUTTON_OK, OLEMSGDEFBUTTON.OLEMSGDEFBUTTON_FIRST); - return null; - } - } - // NOTE: This will be null if connect is canceled - if (remoteSystem != null) - { + success = true; return new SSHConnection(remoteSystem); } + finally + { + if (!success) + { + remoteSystem.Dispose(); + } + } } return null; diff --git a/src/SSHDebugPS/SSH/SSHPortSupplier.cs b/src/SSHDebugPS/SSH/SSHPortSupplier.cs index 349292ee8..ef60efb59 100644 --- a/src/SSHDebugPS/SSH/SSHPortSupplier.cs +++ b/src/SSHDebugPS/SSH/SSHPortSupplier.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; using liblinux; @@ -20,6 +21,8 @@ internal class SSHPortSupplier : AD7PortSupplier { private const string _Name = "SSH"; private readonly Guid _Id = new Guid("3FDDF14E-E758-4695-BE0C-7509920432C9"); + private readonly object _portCacheLock = new object(); + private readonly Dictionary _portCache = new Dictionary(StringComparer.OrdinalIgnoreCase); protected override Guid Id { get { return _Id; } } protected override string Name { get { return _Name; } } @@ -33,11 +36,26 @@ public override int AddPort(IDebugPortRequest2 request, out IDebugPort2 port) string name; HR.Check(request.GetPortName(out name)); - AD7Port newPort = new SSHPort(this, name, isInAddPort: true); + SSHPort sshPort; + lock (_portCacheLock) + { + if (!_portCache.TryGetValue(name, out sshPort)) + { + sshPort = new SSHPort(this, name, isInAddPort: true); + if (sshPort.IsConnected) + { + _portCache[name] = sshPort; + } + } + else + { + sshPort.EnsureConnected(); + } + } - if (newPort.IsConnected) + if (sshPort.IsConnected) { - port = newPort; + port = sshPort; return HR.S_OK; } @@ -53,7 +71,18 @@ public override int EnumPorts(out IEnumDebugPorts2 ppEnum) for (int i = 0; i < store.Connections.Count; i++) { ConnectionInfo connectionInfo = (ConnectionInfo)store.Connections[i]; - ports[i] = new SSHPort(this, GetFormattedSSHConnectionName(connectionInfo), isInAddPort: false); + string name = GetFormattedSSHConnectionName(connectionInfo); + + lock (_portCacheLock) + { + if (!_portCache.TryGetValue(name, out SSHPort port)) + { + port = new SSHPort(this, name, isInAddPort: false); + _portCache[name] = port; + } + + ports[i] = port; + } } ppEnum = new AD7PortEnum(ports);