diff --git a/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json b/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json new file mode 100644 index 00000000000..e48ac161fcd --- /dev/null +++ b/change/react-native-windows-c3827e14-777b-475a-bf00-dc169bf89f3d.json @@ -0,0 +1,7 @@ +{ + "type": "prerelease", + "comment": "Fix WebSocket binaryType handling — stop unconditional Blob interception of binary messages", + "packageName": "react-native-windows", + "email": "gordomacmaster@gmail.com", + "dependentChangeType": "patch" +} \ No newline at end of file diff --git a/vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp b/vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp index f25eb3cdd4d..5948e5b9fb8 100644 --- a/vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp +++ b/vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp @@ -63,6 +63,36 @@ TEST_CLASS (RNTesterHeadlessTests) { auto status = TestModule::AwaitCompletion(); Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)"); } + + BEGIN_TEST_METHOD_ATTRIBUTE(WebSocketArrayBuffer) + TEST_IGNORE() + END_TEST_METHOD_ATTRIBUTE() + TEST_METHOD(WebSocketArrayBuffer) { + TestModule::Reset(); + + winrt::handle instanceLoadedEvent{CreateEvent(nullptr, TRUE, FALSE, nullptr)}; + bool instanceFailed{false}; + + auto holder = TestReactNativeHostHolder( + L"IntegrationTests/WebSocketArrayBufferTest", + [&instanceLoadedEvent, &instanceFailed](msrn::ReactNativeHost const &host) noexcept { + host.InstanceSettings().InstanceLoaded( + [&instanceLoadedEvent, &instanceFailed](auto const &, msrn::InstanceLoadedEventArgs args) noexcept { + instanceFailed = args.Failed(); + SetEvent(instanceLoadedEvent.get()); + }); + }); + + WaitForSingleObject(instanceLoadedEvent.get(), INFINITE); + if (instanceFailed) { + auto err = holder.GetLastError(); + auto msg = L"InstanceLoaded reported failure: " + (err.empty() ? L"(no error captured)" : err); + Assert::Fail(msg.c_str()); + } + + auto status = TestModule::AwaitCompletion(); + Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)"); + } }; } // namespace Microsoft::React::Test diff --git a/vnext/Shared/Modules/IWebSocketModuleContentHandler.h b/vnext/Shared/Modules/IWebSocketModuleContentHandler.h index 4d508603865..f64b0fc65ab 100644 --- a/vnext/Shared/Modules/IWebSocketModuleContentHandler.h +++ b/vnext/Shared/Modules/IWebSocketModuleContentHandler.h @@ -18,11 +18,49 @@ namespace Microsoft::React { struct IWebSocketModuleContentHandler { virtual ~IWebSocketModuleContentHandler() noexcept {} + /// Returns true if this handler should process messages for the given socket. + /// Default returns true for backward compatibility; BlobModule overrides to + /// check whether binaryType='blob' was set for this socket via addWebSocketHandler. + /// + /// WARNING: Subclasses that override Supports() with a stateful or lock-protected + /// check MUST also override both TryProcessMessage() overloads to perform the + /// check-and-process atomically. The default TryProcessMessage() calls Supports() + /// and ProcessMessage() as two separate operations with no lock held between them. + virtual bool Supports(int64_t /*socketId*/) noexcept { + return true; + } + virtual void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept = 0; virtual void ProcessMessage( std::vector &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept = 0; + + /// Check Supports() then ProcessMessage() in one call. + /// Returns true if the message was handled. + /// + /// The default implementation does NOT hold any lock across both operations. + /// Subclasses with a stateful Supports() MUST override these to make the + /// check-and-process atomic (see BlobWebSocketModuleContentHandler for an example). + virtual bool TryProcessMessage( + int64_t socketId, + std::string &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept { + if (!Supports(socketId)) + return false; + ProcessMessage(std::move(message), params); + return true; + } + + virtual bool TryProcessMessage( + int64_t socketId, + std::vector &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept { + if (!Supports(socketId)) + return false; + ProcessMessage(std::move(message), params); + return true; + } }; } // namespace Microsoft::React diff --git a/vnext/Shared/Modules/WebSocketModule.cpp b/vnext/Shared/Modules/WebSocketModule.cpp index 7b0f5bb12a9..6bbc800207f 100644 --- a/vnext/Shared/Modules/WebSocketModule.cpp +++ b/vnext/Shared/Modules/WebSocketModule.cpp @@ -83,6 +83,7 @@ shared_ptr WebSocketTurboModule::CreateResource(int64_t id, if (auto prop = propBag.Get(BlobModuleContentHandlerPropertyId())) contentHandler = prop.Value().lock(); + bool handled = false; if (contentHandler) { if (isBinary) { auto buffer = CryptographicBuffer::DecodeFromBase64String(winrt::to_hstring(message)); @@ -90,11 +91,12 @@ shared_ptr WebSocketTurboModule::CreateResource(int64_t id, CryptographicBuffer::CopyToByteArray(buffer, arr); auto data = vector(arr.begin(), arr.end()); - contentHandler->ProcessMessage(std::move(data), args); + handled = contentHandler->TryProcessMessage(id, std::move(data), args); } else { - contentHandler->ProcessMessage(string{message}, args); + handled = contentHandler->TryProcessMessage(id, string{message}, args); } - } else { + } + if (!handled) { args["data"] = message; } diff --git a/vnext/Shared/Networking/DefaultBlobResource.cpp b/vnext/Shared/Networking/DefaultBlobResource.cpp index 31fdfd6b061..a774d3b700f 100644 --- a/vnext/Shared/Networking/DefaultBlobResource.cpp +++ b/vnext/Shared/Networking/DefaultBlobResource.cpp @@ -221,6 +221,11 @@ BlobWebSocketModuleContentHandler::BlobWebSocketModuleContentHandler(shared_ptr< #pragma region IWebSocketModuleContentHandler +bool BlobWebSocketModuleContentHandler::Supports(int64_t socketId) noexcept /*override*/ { + scoped_lock lock{m_mutex}; + return m_socketIds.find(socketId) != m_socketIds.end(); +} + void BlobWebSocketModuleContentHandler::ProcessMessage( string &&message, msrn::JSValueObject ¶ms) noexcept /*override*/ @@ -241,6 +246,38 @@ void BlobWebSocketModuleContentHandler::ProcessMessage( params[blobKeys.Type] = blobKeys.Blob; } +bool BlobWebSocketModuleContentHandler::TryProcessMessage( + int64_t socketId, + string &&message, + msrn::JSValueObject ¶ms) noexcept /*override*/ +{ + scoped_lock lock{m_mutex}; + if (m_socketIds.find(socketId) == m_socketIds.end()) + return false; + + params[blobKeys.Data] = std::move(message); + return true; +} + +bool BlobWebSocketModuleContentHandler::TryProcessMessage( + int64_t socketId, + vector &&message, + msrn::JSValueObject ¶ms) noexcept /*override*/ +{ + scoped_lock lock{m_mutex}; + if (m_socketIds.find(socketId) == m_socketIds.end()) + return false; + + auto blob = msrn::JSValueObject{ + {blobKeys.Offset, 0}, + {blobKeys.Size, message.size()}, + {blobKeys.BlobId, m_blobPersistor->StoreMessage(std::move(message))}}; + + params[blobKeys.Data] = std::move(blob); + params[blobKeys.Type] = blobKeys.Blob; + return true; +} + #pragma endregion IWebSocketModuleContentHandler void BlobWebSocketModuleContentHandler::Register(int64_t socketID) noexcept { diff --git a/vnext/Shared/Networking/DefaultBlobResource.h b/vnext/Shared/Networking/DefaultBlobResource.h index 4dfdf5f18aa..68f5903ce47 100644 --- a/vnext/Shared/Networking/DefaultBlobResource.h +++ b/vnext/Shared/Networking/DefaultBlobResource.h @@ -51,11 +51,23 @@ class BlobWebSocketModuleContentHandler final : public IWebSocketModuleContentHa #pragma region IWebSocketModuleContentHandler + bool Supports(int64_t socketId) noexcept override; + void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; void ProcessMessage(std::vector &&message, winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + bool TryProcessMessage( + int64_t socketId, + std::string &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + + bool TryProcessMessage( + int64_t socketId, + std::vector &&message, + winrt::Microsoft::ReactNative::JSValueObject ¶ms) noexcept override; + #pragma endregion IWebSocketModuleContentHandler void Register(int64_t socketID) noexcept; diff --git a/vnext/src-win/IntegrationTests/WebSocketArrayBufferTest.js b/vnext/src-win/IntegrationTests/WebSocketArrayBufferTest.js new file mode 100644 index 00000000000..50360b90b55 --- /dev/null +++ b/vnext/src-win/IntegrationTests/WebSocketArrayBufferTest.js @@ -0,0 +1,76 @@ +/** + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT License. + * @format + */ + +'use strict'; + +const {TurboModuleRegistry} = require('react-native'); +const TestModule = TurboModuleRegistry.get('TestModule'); + +if (!TestModule) { + throw new Error('TestModule is not available'); +} + +// eslint-disable-next-line @microsoft/sdl/no-insecure-url +const WS_URL = 'ws://localhost:5555/rnw/rntester/websocketbinarytest'; + +const socket = new WebSocket(WS_URL); +socket.binaryType = 'arraybuffer'; + +socket.addEventListener('open', () => { + socket.send('hello'); +}); + +socket.addEventListener('message', event => { + const data = event.data; + + if (!(data instanceof ArrayBuffer)) { + console.log( + 'WebSocketArrayBufferTest FAIL: expected ArrayBuffer, got ' + typeof data, + ); + TestModule.markTestPassed(false); + socket.close(); + return; + } + + const bytes = new Uint8Array(data); + const expected = new Uint8Array([4, 5, 6, 7]); + + if (bytes.length !== expected.length) { + console.log( + 'WebSocketArrayBufferTest FAIL: expected ' + + expected.length + + ' bytes, got ' + + bytes.length, + ); + TestModule.markTestPassed(false); + socket.close(); + return; + } + + for (let i = 0; i < expected.length; i++) { + if (bytes[i] !== expected[i]) { + console.log( + 'WebSocketArrayBufferTest FAIL: byte[' + + i + + '] expected ' + + expected[i] + + ' got ' + + bytes[i], + ); + TestModule.markTestPassed(false); + socket.close(); + return; + } + } + + TestModule.markTestPassed(true); + socket.close(); +}); + +socket.addEventListener('error', () => { + console.log('WebSocketArrayBufferTest FAIL: WebSocket error'); + TestModule.markTestPassed(false); +});