Skip to content

Commit bff0e9c

Browse files
authored
feat: Implement L01 protocol (#487)
* chore: init try based on Homey logic * fix: some small changes * fix: some small changes * fix: timestamp * fix: some misc bug changes * fix: make sure we are connected on message send * fix: bug fixes for 1.0 * fix: potentially fix ping? * chore: remove debug * fix: add version to ping * fix: remove excluding ping from id check * chore: some potential clean up * chore: comments
1 parent 362ec1d commit bff0e9c

6 files changed

Lines changed: 137 additions & 44 deletions

File tree

roborock/api.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from .roborock_future import RoborockFuture
1919
from .roborock_message import (
2020
RoborockMessage,
21-
RoborockMessageProtocol,
2221
)
2322
from .util import get_next_int
2423

@@ -91,9 +90,7 @@ async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
9190

9291
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
9392
queue = RoborockFuture(protocol_id)
94-
if request_id in self._waiting_queue and not (
95-
request_id == 2 and protocol_id == RoborockMessageProtocol.PING_REQUEST
96-
):
93+
if request_id in self._waiting_queue:
9794
new_id = get_next_int(10000, 32767)
9895
self._logger.warning(
9996
"Attempting to create a future with an existing id %s (%s)... New id is %s. "

roborock/protocol.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,16 @@ def _encode(self, obj, context, _):
282282
iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25]
283283
decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8"))
284284
return decipher.encrypt(obj)
285+
elif context.version == b"L01":
286+
return Utils.encrypt_gcm_l01(
287+
plaintext=obj,
288+
local_key=context.search("local_key"),
289+
timestamp=context.timestamp,
290+
sequence=context.seq,
291+
nonce=context.random,
292+
connect_nonce=context.search("connect_nonce"),
293+
ack_nonce=context.search("ack_nonce"),
294+
)
285295
token = self.token_func(context)
286296
encrypted = Utils.encrypt_ecb(obj, token)
287297
return encrypted
@@ -297,6 +307,16 @@ def _decode(self, obj, context, _):
297307
iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25]
298308
decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8"))
299309
return decipher.decrypt(obj)
310+
elif context.version == b"L01":
311+
return Utils.decrypt_gcm_l01(
312+
payload=obj,
313+
local_key=context.search("local_key"),
314+
timestamp=context.timestamp,
315+
sequence=context.seq,
316+
nonce=context.random,
317+
connect_nonce=context.search("connect_nonce"),
318+
ack_nonce=context.search("ack_nonce"),
319+
)
300320
token = self.token_func(context)
301321
decrypted = Utils.decrypt_ecb(obj, token)
302322
return decrypted
@@ -321,7 +341,7 @@ class PrefixedStruct(Struct):
321341
def _parse(self, stream, context, path):
322342
subcon1 = Peek(Optional(Bytes(3)))
323343
peek_version = subcon1.parse_stream(stream, **context)
324-
if peek_version not in (b"1.0", b"A01", b"B01"):
344+
if peek_version not in (b"1.0", b"A01", b"B01", b"L01"):
325345
subcon2 = Bytes(4)
326346
subcon2.parse_stream(stream, **context)
327347
return super()._parse(stream, context, path)
@@ -374,10 +394,12 @@ def __init__(self, con: Construct, required_local_key: bool):
374394
self.con = con
375395
self.required_local_key = required_local_key
376396

377-
def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[RoborockMessage], bytes]:
397+
def parse(
398+
self, data: bytes, local_key: str | None = None, connect_nonce: int | None = None, ack_nonce: int | None = None
399+
) -> tuple[list[RoborockMessage], bytes]:
378400
if self.required_local_key and local_key is None:
379401
raise RoborockException("Local key is required")
380-
parsed = self.con.parse(data, local_key=local_key)
402+
parsed = self.con.parse(data, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
381403
parsed_messages = [Container({"message": parsed.message})] if parsed.get("message") else parsed.messages
382404
messages = []
383405
for message in parsed_messages:
@@ -395,7 +417,12 @@ def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[Roboroc
395417
return messages, remaining
396418

397419
def build(
398-
self, roborock_messages: list[RoborockMessage] | RoborockMessage, local_key: str, prefixed: bool = True
420+
self,
421+
roborock_messages: list[RoborockMessage] | RoborockMessage,
422+
local_key: str,
423+
prefixed: bool = True,
424+
connect_nonce: int | None = None,
425+
ack_nonce: int | None = None,
399426
) -> bytes:
400427
if isinstance(roborock_messages, RoborockMessage):
401428
roborock_messages = [roborock_messages]
@@ -416,7 +443,11 @@ def build(
416443
}
417444
)
418445
return self.con.build(
419-
{"messages": [message for message in messages], "remaining": b""}, local_key=local_key, prefixed=prefixed
446+
{"messages": [message for message in messages], "remaining": b""},
447+
local_key=local_key,
448+
prefixed=prefixed,
449+
connect_nonce=connect_nonce,
450+
ack_nonce=ack_nonce,
420451
)
421452

422453

@@ -466,29 +497,31 @@ def encode(messages: RoborockMessage) -> bytes:
466497
return encode
467498

468499

469-
def create_local_decoder(local_key: str) -> Decoder:
500+
def create_local_decoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Decoder:
470501
"""Create a decoder for local API messages."""
471502

472503
# This buffer is used to accumulate bytes until a complete message can be parsed.
473504
# It is defined outside the decode function to maintain state across calls.
474505
buffer: bytes = b""
475506

476-
def decode(bytes: bytes) -> list[RoborockMessage]:
507+
def decode(bytes_data: bytes) -> list[RoborockMessage]:
477508
"""Parse the given data into Roborock messages."""
478509
nonlocal buffer
479-
buffer += bytes
480-
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
510+
buffer += bytes_data
511+
parsed_messages, remaining = MessageParser.parse(
512+
buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce
513+
)
481514
buffer = remaining
482515
return parsed_messages
483516

484517
return decode
485518

486519

487-
def create_local_encoder(local_key: str) -> Encoder:
520+
def create_local_encoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Encoder:
488521
"""Create an encoder for local API messages."""
489522

490523
def encode(message: RoborockMessage) -> bytes:
491524
"""Called when data is sent to the transport."""
492-
return MessageParser.build(message, local_key=local_key)
525+
return MessageParser.build(message, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
493526

494527
return encode

roborock/protocols/v1_protocol.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,14 @@ class RequestMessage:
6565
request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))
6666

6767
def encode_message(
68-
self,
69-
protocol: RoborockMessageProtocol,
70-
security_data: SecurityData | None = None,
68+
self, protocol: RoborockMessageProtocol, security_data: SecurityData | None = None, version: str = "1.0"
7169
) -> RoborockMessage:
7270
"""Convert the request message to a RoborockMessage."""
7371
return RoborockMessage(
7472
timestamp=self.timestamp,
7573
protocol=protocol,
7674
payload=self._as_payload(security_data=security_data),
75+
version=version.encode(),
7776
)
7877

7978
def _as_payload(self, security_data: SecurityData | None) -> bytes:

roborock/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, Muta
9090
counter_map: dict[tuple[int, int], int] = {}
9191

9292

93-
def get_next_int(min_val: int, max_val: int):
93+
def get_next_int(min_val: int, max_val: int) -> int:
9494
"""Gets a random int in the range, precached to help keep it fast."""
9595
if (min_val, max_val) not in counter_map:
9696
# If we have never seen this range, or if the cache is getting low, make a bunch of preshuffled values.

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,10 +447,16 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
447447
self._logger.debug(
448448
"Received unsolicited map response for request_id %s", map_response.request_id
449449
)
450+
elif data.protocol == RoborockMessageProtocol.GENERAL_RESPONSE and data.payload is None:
451+
# Api will often send blank messages with matching sequences, we can ignore these.
452+
continue
450453
else:
451454
queue = self._waiting_queue.get(data.seq)
452455
if queue:
453-
queue.set_result(data.payload)
456+
if data.protocol == RoborockMessageProtocol.HELLO_RESPONSE:
457+
queue.set_result(data)
458+
else:
459+
queue.set_result(data.payload)
454460
else:
455461
self._logger.debug("Received response for unknown request id %s", data.seq)
456462
except Exception as ex:

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,27 @@
33
from asyncio import Lock, TimerHandle, Transport, get_running_loop
44
from collections.abc import Callable
55
from dataclasses import dataclass
6+
from enum import StrEnum
67

78
import async_timeout
89

910
from .. import CommandVacuumError, DeviceData, RoborockCommand
1011
from ..api import RoborockClient
1112
from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
12-
from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
13+
from ..protocol import create_local_decoder, create_local_encoder
1314
from ..protocols.v1_protocol import RequestMessage
1415
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
15-
from ..util import RoborockLoggerAdapter
16+
from ..util import RoborockLoggerAdapter, get_next_int
1617
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
1718

1819
_LOGGER = logging.getLogger(__name__)
1920

2021

21-
_HELLO_REQUEST_MESSAGE = RoborockMessage(
22-
protocol=RoborockMessageProtocol.HELLO_REQUEST,
23-
seq=1,
24-
random=22,
25-
)
22+
class LocalProtocolVersion(StrEnum):
23+
"""Supported local protocol versions. Different from vacuum protocol versions."""
2624

27-
_PING_REQUEST_MESSAGE = RoborockMessage(
28-
protocol=RoborockMessageProtocol.PING_REQUEST,
29-
seq=2,
30-
random=23,
31-
)
25+
L01 = "L01"
26+
V1 = "1.0"
3227

3328

3429
@dataclass
@@ -50,7 +45,12 @@ def connection_lost(self, exc: Exception | None) -> None:
5045
class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
5146
"""Roborock local client for v1 devices."""
5247

53-
def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
48+
def __init__(
49+
self,
50+
device_data: DeviceData,
51+
queue_timeout: int = 4,
52+
local_protocol_version: LocalProtocolVersion | None = None,
53+
):
5454
"""Initialize the Roborock local client."""
5555
if device_data.host is None:
5656
raise RoborockException("Host is required")
@@ -63,11 +63,17 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
6363
RoborockClientV1.__init__(self, device_data, security_data=None)
6464
RoborockClient.__init__(self, device_data)
6565
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
66-
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
67-
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
66+
self._local_protocol_version = local_protocol_version
67+
self._connect_nonce = get_next_int(10000, 32767)
68+
self._ack_nonce: int | None = None
69+
self._set_encoder_decoder()
6870
self.queue_timeout = queue_timeout
6971
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
7072

73+
@property
74+
def local_protocol_version(self) -> LocalProtocolVersion:
75+
return LocalProtocolVersion.V1 if self._local_protocol_version is None else self._local_protocol_version
76+
7177
def _data_received(self, message):
7278
"""Called when data is received from the transport."""
7379
parsed_msg = self._decoder(message)
@@ -121,20 +127,69 @@ async def async_disconnect(self) -> None:
121127
async with self._mutex:
122128
self._sync_disconnect()
123129

124-
async def hello(self):
130+
def _set_encoder_decoder(self):
131+
"""Updates the encoder decoder. These are updated with nonces after the first hello.
132+
Only L01 uses the nonces."""
133+
self._encoder = create_local_encoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce)
134+
self._decoder = create_local_decoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce)
135+
136+
async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> bool:
137+
"""Perform the initial handshaking."""
138+
self._logger.debug(
139+
"Attempting to use the %s protocol for client %s...",
140+
local_protocol_version,
141+
self.device_info.device.duid,
142+
)
143+
request = RoborockMessage(
144+
protocol=RoborockMessageProtocol.HELLO_REQUEST,
145+
version=local_protocol_version.encode(),
146+
random=self._connect_nonce,
147+
seq=1,
148+
)
125149
try:
126-
return await self._send_message(
127-
roborock_message=_HELLO_REQUEST_MESSAGE,
128-
request_id=_HELLO_REQUEST_MESSAGE.seq,
150+
response = await self._send_message(
151+
roborock_message=request,
152+
request_id=request.seq,
129153
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
130154
)
131-
except Exception as e:
132-
self._logger.error(e)
155+
self._ack_nonce = response.random
156+
self._set_encoder_decoder()
157+
self._local_protocol_version = local_protocol_version
158+
159+
self._logger.debug(
160+
"Client %s speaks the %s protocol.",
161+
self.device_info.device.duid,
162+
local_protocol_version,
163+
)
164+
return True
165+
except RoborockException as e:
166+
self._logger.debug(
167+
"Client %s did not respond or does not speak the %s protocol. %s",
168+
self.device_info.device.duid,
169+
local_protocol_version,
170+
e,
171+
)
172+
return False
173+
174+
async def hello(self):
175+
"""Send hello to the device to negotiate protocol."""
176+
if self._local_protocol_version:
177+
# version is forced
178+
if not await self._do_hello(self._local_protocol_version):
179+
raise RoborockException(f"Failed to connect to device with protocol {self._local_protocol_version}")
180+
else:
181+
# try 1.0, then L01
182+
if not await self._do_hello(LocalProtocolVersion.V1):
183+
if not await self._do_hello(LocalProtocolVersion.L01):
184+
raise RoborockException("Failed to connect to device with any known protocol")
133185

134186
async def ping(self) -> None:
187+
ping_message = RoborockMessage(
188+
protocol=RoborockMessageProtocol.PING_REQUEST, version=self.local_protocol_version.encode()
189+
)
135190
await self._send_message(
136-
roborock_message=_PING_REQUEST_MESSAGE,
137-
request_id=_PING_REQUEST_MESSAGE.seq,
191+
roborock_message=ping_message,
192+
request_id=ping_message.seq,
138193
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
139194
)
140195

@@ -160,7 +215,10 @@ async def _send_command(
160215
if method in CLOUD_REQUIRED:
161216
raise RoborockException(f"Method {method} is not supported over local connection")
162217
request_message = RequestMessage(method=method, params=params)
163-
roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
218+
roborock_message = request_message.encode_message(
219+
RoborockMessageProtocol.GENERAL_REQUEST,
220+
version=self.local_protocol_version,
221+
)
164222
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
165223
return await self._send_message(
166224
roborock_message,

0 commit comments

Comments
 (0)