33from asyncio import Lock , TimerHandle , Transport , get_running_loop
44from collections .abc import Callable
55from dataclasses import dataclass
6+ from enum import StrEnum
67
78import async_timeout
89
910from .. import CommandVacuumError , DeviceData , RoborockCommand
1011from ..api import RoborockClient
1112from ..exceptions import RoborockConnectionException , RoborockException , VacuumError
12- from ..protocol import Decoder , Encoder , create_local_decoder , create_local_encoder
13+ from ..protocol import create_local_decoder , create_local_encoder
1314from ..protocols .v1_protocol import RequestMessage
1415from ..roborock_message import RoborockMessage , RoborockMessageProtocol
15- from ..util import RoborockLoggerAdapter
16+ from ..util import RoborockLoggerAdapter , get_next_int
1617from .roborock_client_v1 import CLOUD_REQUIRED , RoborockClientV1
1718
1819_LOGGER = logging .getLogger (__name__ )
1920
2021
21- _HELLO_REQUEST_MESSAGE = RoborockMessage (
22- protocol = RoborockMessageProtocol .HELLO_REQUEST ,
23- seq = 1 ,
24- random = 22 ,
25- )
22+ class LocalProtocolVersion (StrEnum ):
23+ """Supported local protocol versions. Different from vacuum protocol versions."""
2624
27- _PING_REQUEST_MESSAGE = RoborockMessage (
28- protocol = RoborockMessageProtocol .PING_REQUEST ,
29- seq = 2 ,
30- random = 23 ,
31- )
25+ L01 = "L01"
26+ V1 = "1.0"
3227
3328
3429@dataclass
@@ -50,7 +45,12 @@ def connection_lost(self, exc: Exception | None) -> None:
5045class RoborockLocalClientV1 (RoborockClientV1 , RoborockClient ):
5146 """Roborock local client for v1 devices."""
5247
53- def __init__ (self , device_data : DeviceData , queue_timeout : int = 4 ):
48+ def __init__ (
49+ self ,
50+ device_data : DeviceData ,
51+ queue_timeout : int = 4 ,
52+ local_protocol_version : LocalProtocolVersion | None = None ,
53+ ):
5454 """Initialize the Roborock local client."""
5555 if device_data .host is None :
5656 raise RoborockException ("Host is required" )
@@ -63,11 +63,17 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
6363 RoborockClientV1 .__init__ (self , device_data , security_data = None )
6464 RoborockClient .__init__ (self , device_data )
6565 self ._local_protocol = _LocalProtocol (self ._data_received , self ._connection_lost )
66- self ._encoder : Encoder = create_local_encoder (device_data .device .local_key )
67- self ._decoder : Decoder = create_local_decoder (device_data .device .local_key )
66+ self ._local_protocol_version = local_protocol_version
67+ self ._connect_nonce = get_next_int (10000 , 32767 )
68+ self ._ack_nonce : int | None = None
69+ self ._set_encoder_decoder ()
6870 self .queue_timeout = queue_timeout
6971 self ._logger = RoborockLoggerAdapter (device_data .device .name , _LOGGER )
7072
73+ @property
74+ def local_protocol_version (self ) -> LocalProtocolVersion :
75+ return LocalProtocolVersion .V1 if self ._local_protocol_version is None else self ._local_protocol_version
76+
7177 def _data_received (self , message ):
7278 """Called when data is received from the transport."""
7379 parsed_msg = self ._decoder (message )
@@ -121,20 +127,69 @@ async def async_disconnect(self) -> None:
121127 async with self ._mutex :
122128 self ._sync_disconnect ()
123129
124- async def hello (self ):
130+ def _set_encoder_decoder (self ):
131+ """Updates the encoder decoder. These are updated with nonces after the first hello.
132+ Only L01 uses the nonces."""
133+ self ._encoder = create_local_encoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
134+ self ._decoder = create_local_decoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
135+
136+ async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> bool :
137+ """Perform the initial handshaking."""
138+ self ._logger .debug (
139+ "Attempting to use the %s protocol for client %s..." ,
140+ local_protocol_version ,
141+ self .device_info .device .duid ,
142+ )
143+ request = RoborockMessage (
144+ protocol = RoborockMessageProtocol .HELLO_REQUEST ,
145+ version = local_protocol_version .encode (),
146+ random = self ._connect_nonce ,
147+ seq = 1 ,
148+ )
125149 try :
126- return await self ._send_message (
127- roborock_message = _HELLO_REQUEST_MESSAGE ,
128- request_id = _HELLO_REQUEST_MESSAGE .seq ,
150+ response = await self ._send_message (
151+ roborock_message = request ,
152+ request_id = request .seq ,
129153 response_protocol = RoborockMessageProtocol .HELLO_RESPONSE ,
130154 )
131- except Exception as e :
132- self ._logger .error (e )
155+ self ._ack_nonce = response .random
156+ self ._set_encoder_decoder ()
157+ self ._local_protocol_version = local_protocol_version
158+
159+ self ._logger .debug (
160+ "Client %s speaks the %s protocol." ,
161+ self .device_info .device .duid ,
162+ local_protocol_version ,
163+ )
164+ return True
165+ except RoborockException as e :
166+ self ._logger .debug (
167+ "Client %s did not respond or does not speak the %s protocol. %s" ,
168+ self .device_info .device .duid ,
169+ local_protocol_version ,
170+ e ,
171+ )
172+ return False
173+
174+ async def hello (self ):
175+ """Send hello to the device to negotiate protocol."""
176+ if self ._local_protocol_version :
177+ # version is forced
178+ if not await self ._do_hello (self ._local_protocol_version ):
179+ raise RoborockException (f"Failed to connect to device with protocol { self ._local_protocol_version } " )
180+ else :
181+ # try 1.0, then L01
182+ if not await self ._do_hello (LocalProtocolVersion .V1 ):
183+ if not await self ._do_hello (LocalProtocolVersion .L01 ):
184+ raise RoborockException ("Failed to connect to device with any known protocol" )
133185
134186 async def ping (self ) -> None :
187+ ping_message = RoborockMessage (
188+ protocol = RoborockMessageProtocol .PING_REQUEST , version = self .local_protocol_version .encode ()
189+ )
135190 await self ._send_message (
136- roborock_message = _PING_REQUEST_MESSAGE ,
137- request_id = _PING_REQUEST_MESSAGE .seq ,
191+ roborock_message = ping_message ,
192+ request_id = ping_message .seq ,
138193 response_protocol = RoborockMessageProtocol .PING_RESPONSE ,
139194 )
140195
@@ -160,7 +215,10 @@ async def _send_command(
160215 if method in CLOUD_REQUIRED :
161216 raise RoborockException (f"Method { method } is not supported over local connection" )
162217 request_message = RequestMessage (method = method , params = params )
163- roborock_message = request_message .encode_message (RoborockMessageProtocol .GENERAL_REQUEST )
218+ roborock_message = request_message .encode_message (
219+ RoborockMessageProtocol .GENERAL_REQUEST ,
220+ version = self .local_protocol_version ,
221+ )
164222 self ._logger .debug ("Building message id %s for method %s" , request_message .request_id , method )
165223 return await self ._send_message (
166224 roborock_message ,
0 commit comments