|
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 |
|
10 | | -from roborock.devices.local_channel import LocalChannel |
| 10 | +from roborock.devices.local_channel import LocalChannel, LocalChannelParams |
11 | 11 | from roborock.exceptions import RoborockConnectionException |
12 | 12 | from roborock.protocol import create_local_decoder, create_local_encoder |
| 13 | +from roborock.protocols.v1_protocol import LocalProtocolVersion |
13 | 14 | from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol |
14 | 15 |
|
15 | 16 | TEST_HOST = "192.168.1.100" |
@@ -56,9 +57,18 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]: |
56 | 57 |
|
57 | 58 |
|
58 | 59 | @pytest.fixture(name="local_channel") |
59 | | -def setup_local_channel() -> LocalChannel: |
60 | | - """Fixture to set up the local channel for tests.""" |
61 | | - return LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 60 | +async def setup_local_channel_with_hello_mock() -> LocalChannel: |
| 61 | + """Fixture to set up the local channel with automatic hello mocking.""" |
| 62 | + channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 63 | + |
| 64 | + async def mock_do_hello(local_protocol_version): |
| 65 | + """Mock _do_hello to return successful params without sending actual request.""" |
| 66 | + return LocalChannelParams(local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321) |
| 67 | + |
| 68 | + # Replace the _do_hello method |
| 69 | + setattr(channel, "_do_hello", mock_do_hello) |
| 70 | + |
| 71 | + return channel |
62 | 72 |
|
63 | 73 |
|
64 | 74 | @pytest.fixture(name="received_messages") |
@@ -231,3 +241,33 @@ async def test_connection_lost_without_exception( |
231 | 241 | assert local_channel._is_connected is False |
232 | 242 | assert local_channel._transport is None |
233 | 243 | assert "Connection lost to 192.168.1.100" in caplog.text |
| 244 | + |
| 245 | + |
| 246 | +async def test_hello_fallback_to_l01_protocol(mock_loop: Mock, mock_transport: Mock) -> None: |
| 247 | + """Test that when first hello() message fails (V1) but second succeeds (L01), we use L01.""" |
| 248 | + |
| 249 | + # Create a channel without the automatic hello mocking |
| 250 | + channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 251 | + |
| 252 | + # Mock _do_hello to fail for V1 but succeed for L01 |
| 253 | + async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None: |
| 254 | + if local_protocol_version == LocalProtocolVersion.V1: |
| 255 | + # First attempt (V1) fails - return None to simulate failure |
| 256 | + return None |
| 257 | + elif local_protocol_version == LocalProtocolVersion.L01: |
| 258 | + # Second attempt (L01) succeeds |
| 259 | + return LocalChannelParams( |
| 260 | + local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321 |
| 261 | + ) |
| 262 | + return None |
| 263 | + |
| 264 | + # Replace the _do_hello method |
| 265 | + setattr(channel, "_do_hello", mock_do_hello) |
| 266 | + |
| 267 | + # Connect and verify L01 protocol is used |
| 268 | + await channel.connect() |
| 269 | + |
| 270 | + # Verify that the channel is using L01 protocol |
| 271 | + assert channel._local_protocol_version == LocalProtocolVersion.L01 |
| 272 | + assert channel._ack_nonce == 54321 |
| 273 | + assert channel._is_connected is True |
0 commit comments