diff --git a/src/win/conpty.cc b/src/win/conpty.cc index 633102d7f..66741d9d7 100644 --- a/src/win/conpty.cc +++ b/src/win/conpty.cc @@ -14,6 +14,8 @@ #include #include #include // PathCombine, PathIsRelative +#include +#include #include #include #include @@ -50,9 +52,12 @@ struct pty_baton { }; static std::vector> ptyHandles; -static volatile LONG ptyCounter; +static std::mutex g_ptyHandlesMutex; +static std::atomic ptyCounter{0}; -static pty_baton* get_pty_baton(int id) { +// The leading scoped-lock parameter encodes the precondition that the caller +// holds g_ptyHandlesMutex. +static pty_baton* get_pty_baton(const std::lock_guard&, int id) { auto it = std::find_if(ptyHandles.begin(), ptyHandles.end(), [id](const auto& ptyHandle) { return ptyHandle->id == id; }); @@ -62,17 +67,6 @@ static pty_baton* get_pty_baton(int id) { return nullptr; } -static bool remove_pty_baton(int id) { - auto it = std::remove_if(ptyHandles.begin(), ptyHandles.end(), [id](const auto& ptyHandle) { - return ptyHandle->id == id; - }); - if (it != ptyHandles.end()) { - ptyHandles.erase(it); - return true; - } - return false; -} - struct ExitEvent { int exit_code = 0; }; @@ -99,11 +93,15 @@ void SetupExitCallback(Napi::Env env, Napi::Function cb, pty_baton* baton) { ExitEvent *exit_event = new ExitEvent; // Wait for process to complete. WaitForSingleObject(baton->hShell, INFINITE); - // Get process exit code. - GetExitCodeProcess(baton->hShell, (LPDWORD)(&exit_event->exit_code)); - // Clean up handles - CloseHandle(baton->hShell); - assert(remove_pty_baton(baton->id)); + { + std::lock_guard lock(g_ptyHandlesMutex); + GetExitCodeProcess(baton->hShell, (LPDWORD)(&exit_event->exit_code)); + CloseHandle(baton->hShell); + const int id = baton->id; + std::erase_if(ptyHandles, [id](const auto& ptyHandle) { + return ptyHandle->id == id; + }); + } auto status = tsfn.BlockingCall(exit_event, callback); // In main thread switch (status) { @@ -298,10 +296,13 @@ static Napi::Value PtyStartProcess(const Napi::CallbackInfo& info) { if (SUCCEEDED(hr)) { // We were able to instantiate a conpty - const int ptyId = InterlockedIncrement(&ptyCounter); + const int ptyId = ++ptyCounter; marshal.Set("pty", Napi::Number::New(env, ptyId)); - ptyHandles.emplace_back( - std::make_unique(ptyId, hIn, hOut, hpc)); + { + std::lock_guard lock(g_ptyHandlesMutex); + ptyHandles.emplace_back( + std::make_unique(ptyId, hIn, hOut, hpc)); + } } else { throw Napi::Error::New(env, "Cannot launch conpty"); } @@ -349,10 +350,13 @@ static Napi::Value PtyConnect(const Napi::CallbackInfo& info) { const bool useConptyDll = info[4].As().Value(); Napi::Function exitCallback = info[5].As(); - // Fetch pty handle from ID and start process - pty_baton* handle = get_pty_baton(id); - if (!handle) { - throw Napi::Error::New(env, "Invalid pty handle"); + pty_baton* handle; + { + std::lock_guard lock(g_ptyHandlesMutex); + handle = get_pty_baton(lock, id); + if (!handle) { + throw Napi::Error::New(env, "Invalid pty handle"); + } } // Prepare command line @@ -471,7 +475,8 @@ static Napi::Value PtyResize(const Napi::CallbackInfo& info) { SHORT rows = static_cast(info[2].As().Uint32Value()); const bool useConptyDll = info[3].As().Value(); - const pty_baton* handle = get_pty_baton(id); + std::lock_guard lock(g_ptyHandlesMutex); + const pty_baton* handle = get_pty_baton(lock, id); if (handle != nullptr) { HANDLE hLibrary = LoadConptyDll(info, useConptyDll); @@ -512,7 +517,8 @@ static Napi::Value PtyClear(const Napi::CallbackInfo& info) { return env.Undefined(); } - const pty_baton* handle = get_pty_baton(id); + std::lock_guard lock(g_ptyHandlesMutex); + const pty_baton* handle = get_pty_baton(lock, id); if (handle != nullptr) { HANDLE hLibrary = LoadConptyDll(info, useConptyDll); @@ -543,7 +549,8 @@ static Napi::Value PtyKill(const Napi::CallbackInfo& info) { int id = info[0].As().Int32Value(); const bool useConptyDll = info[1].As().Value(); - const pty_baton* handle = get_pty_baton(id); + std::lock_guard lock(g_ptyHandlesMutex); + const pty_baton* handle = get_pty_baton(lock, id); if (handle != nullptr) { HANDLE hLibrary = LoadConptyDll(info, useConptyDll); diff --git a/src/windowsTerminal.test.ts b/src/windowsTerminal.test.ts index 28c102bcb..303086ae4 100644 --- a/src/windowsTerminal.test.ts +++ b/src/windowsTerminal.test.ts @@ -268,6 +268,58 @@ if (process.platform === 'win32') { }); }); }); + + describe('Regression for #921', () => { + it('should not crash with concurrent kills while resizing/clearing', function (done) { + this.timeout(60000); + const N = 30; + const terms: WindowsTerminal[] = []; + let ready = 0; + let exited = 0; + let spamInterval: NodeJS.Timeout | undefined; + const cleanup = (err?: Error): void => { + if (spamInterval) { + clearInterval(spamInterval); + spamInterval = undefined; + } + done(err); + }; + const startRace = (): void => { + spamInterval = setInterval(() => { + for (const t of terms) { + try { + t.resize(80 + Math.floor(Math.random() * 40), 24 + Math.floor(Math.random() * 20)); + } catch (e) { /* already exited */ } + try { + t.clear(); + } catch (e) { /* already exited */ } + } + }, 1); + for (const t of terms) { + try { t.kill(); } catch (e) { /* */ } + } + }; + for (let i = 0; i < N; i++) { + const t = new WindowsTerminal('cmd.exe', [], { useConptyDll }); + terms.push(t); + let readied = false; + t.on('data', () => { + if (readied) return; + readied = true; + ready++; + if (ready === N) { + startRace(); + } + }); + t.on('exit', () => { + exited++; + if (exited === N) { + cleanup(); + } + }); + } + }); + }); }); }); }