Skip to content

Commit aea98ed

Browse files
committed
🦎 roborock: address review items for q7 map + segment support
1 parent b4fb0d4 commit aea98ed

File tree

9 files changed

+124
-53
lines changed

9 files changed

+124
-53
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"pyrate-limiter>=3.7.0,<4",
3030
"aiomqtt>=2.5.0,<3",
3131
"click-shell~=2.1",
32+
"Pillow>=10,<12",
3233
]
3334

3435
[project.urls]

roborock/devices/rpc/b01_q7_channel.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import json
77
import logging
8+
import weakref
89
from typing import Any
910

1011
from roborock.devices.transport.mqtt_channel import MqttChannel
@@ -18,6 +19,15 @@
1819

1920
_LOGGER = logging.getLogger(__name__)
2021
_TIMEOUT = 10.0
22+
_map_command_locks: weakref.WeakKeyDictionary[MqttChannel, asyncio.Lock] = weakref.WeakKeyDictionary()
23+
24+
25+
def _get_map_command_lock(mqtt_channel: MqttChannel) -> asyncio.Lock:
26+
lock = _map_command_locks.get(mqtt_channel)
27+
if lock is None:
28+
lock = asyncio.Lock()
29+
_map_command_locks[mqtt_channel] = lock
30+
return lock
2131

2232

2333
async def send_decoded_command(
@@ -102,49 +112,54 @@ def find_response(response_message: RoborockMessage) -> None:
102112

103113

104114
async def send_map_command(mqtt_channel: MqttChannel, request_message: Q7RequestMessage) -> bytes:
105-
"""Send map upload command and wait for MAP_RESPONSE payload bytes."""
115+
"""Send map upload command and wait for MAP_RESPONSE payload bytes.
106116
107-
roborock_message = encode_mqtt_payload(request_message)
108-
future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future()
117+
Map requests are serialized per channel so concurrent map calls cannot
118+
cross-wire responses between callers.
119+
"""
109120

110-
def find_response(response_message: RoborockMessage) -> None:
111-
if future.done():
112-
return
121+
async with _get_map_command_lock(mqtt_channel):
122+
roborock_message = encode_mqtt_payload(request_message)
123+
future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future()
113124

114-
if response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE and response_message.payload:
115-
future.set_result(response_message.payload)
116-
return
125+
def find_response(response_message: RoborockMessage) -> None:
126+
if future.done():
127+
return
117128

118-
try:
119-
decoded_dps = decode_rpc_response(response_message)
120-
except RoborockException:
121-
return
129+
if response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE and response_message.payload:
130+
future.set_result(response_message.payload)
131+
return
122132

123-
for dps_value in decoded_dps.values():
124-
if not isinstance(dps_value, str):
125-
continue
126133
try:
127-
inner = json.loads(dps_value)
128-
except (json.JSONDecodeError, TypeError):
129-
continue
130-
if not isinstance(inner, dict) or inner.get("msgId") != str(request_message.msg_id):
131-
continue
132-
code = inner.get("code", 0)
133-
if code != 0:
134-
future.set_exception(RoborockException(f"B01 command failed with code {code} ({request_message})"))
134+
decoded_dps = decode_rpc_response(response_message)
135+
except RoborockException:
135136
return
136-
data = inner.get("data")
137-
if isinstance(data, dict) and isinstance(data.get("payload"), str):
137+
138+
for dps_value in decoded_dps.values():
139+
if not isinstance(dps_value, str):
140+
continue
138141
try:
139-
future.set_result(bytes.fromhex(data["payload"]))
140-
except ValueError:
141-
pass
142+
inner = json.loads(dps_value)
143+
except (json.JSONDecodeError, TypeError):
144+
continue
145+
if not isinstance(inner, dict) or inner.get("msgId") != str(request_message.msg_id):
146+
continue
147+
code = inner.get("code", 0)
148+
if code != 0:
149+
future.set_exception(RoborockException(f"B01 command failed with code {code} ({request_message})"))
150+
return
151+
data = inner.get("data")
152+
if isinstance(data, dict) and isinstance(data.get("payload"), str):
153+
try:
154+
future.set_result(bytes.fromhex(data["payload"]))
155+
except ValueError:
156+
pass
142157

143-
unsub = await mqtt_channel.subscribe(find_response)
144-
try:
145-
await mqtt_channel.publish(roborock_message)
146-
return await asyncio.wait_for(future, timeout=_TIMEOUT)
147-
except TimeoutError as ex:
148-
raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex
149-
finally:
150-
unsub()
158+
unsub = await mqtt_channel.subscribe(find_response)
159+
try:
160+
await mqtt_channel.publish(roborock_message)
161+
return await asyncio.wait_for(future, timeout=_TIMEOUT)
162+
except TimeoutError as ex:
163+
raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex
164+
finally:
165+
unsub()

roborock/devices/traits/b01/q7/__init__.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,24 @@ class Q7PropertiesApi(Trait):
3535

3636
clean_summary: CleanSummaryTrait
3737
"""Trait for clean records / clean summary (Q7 `service.get_record_list`)."""
38-
map_content: Q7MapContentTrait
39-
40-
def __init__(self, channel: MqttChannel, *, local_key: str, serial: str, model: str) -> None:
38+
map_content: Q7MapContentTrait | None
39+
40+
def __init__(
41+
self,
42+
channel: MqttChannel,
43+
*,
44+
local_key: str | None = None,
45+
serial: str | None = None,
46+
model: str | None = None,
47+
) -> None:
4148
"""Initialize the B01Props API."""
4249
self._channel = channel
4350
self.clean_summary = CleanSummaryTrait(channel)
44-
self.map_content = Q7MapContentTrait(channel, local_key=local_key, serial=serial, model=model)
51+
if local_key and serial and model:
52+
self.map_content = Q7MapContentTrait(channel, local_key=local_key, serial=serial, model=model)
53+
else:
54+
# Keep backwards compatibility for direct callers that only use command/query traits.
55+
self.map_content = None
4556

4657
async def query_values(self, props: list[RoborockB01Props]) -> B01Props | None:
4758
"""Query the device for the values of the given Q7 properties."""
@@ -91,14 +102,14 @@ async def start_clean(self) -> None:
91102
},
92103
)
93104

94-
async def clean_segments(self, room_ids: list[int]) -> None:
95-
"""Start segment/room cleaning for the given room ids."""
105+
async def clean_segments(self, segment_ids: list[int]) -> None:
106+
"""Start segment cleaning for the given ids (Q7 uses room ids)."""
96107
await self.send(
97108
command=RoborockB01Q7Methods.SET_ROOM_CLEAN,
98109
params={
99110
"clean_type": CleanTaskTypeMapping.ROOM.code,
100111
"ctrl_value": SCDeviceCleanParam.START.code,
101-
"room_ids": room_ids,
112+
"room_ids": segment_ids,
102113
},
103114
)
104115

@@ -146,6 +157,12 @@ async def send(self, command: CommandType, params: ParamsType) -> Any:
146157
)
147158

148159

149-
def create(channel: MqttChannel, *, local_key: str, serial: str, model: str) -> Q7PropertiesApi:
160+
def create(
161+
channel: MqttChannel,
162+
*,
163+
local_key: str | None = None,
164+
serial: str | None = None,
165+
model: str | None = None,
166+
) -> Q7PropertiesApi:
150167
"""Create traits for B01 Q7 devices."""
151168
return Q7PropertiesApi(channel, local_key=local_key, serial=serial, model=model)

roborock/devices/traits/b01/q7/map_content.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
from roborock.devices.traits.v1.map_content import MapContent
1111
from roborock.devices.transport.mqtt_channel import MqttChannel
1212
from roborock.exceptions import RoborockException
13-
from roborock.map.b01_map_parser import (
14-
decode_b01_map_payload,
15-
parse_scmap_payload,
16-
render_map_png,
17-
)
13+
from roborock.map.b01_map_parser import decode_b01_map_payload, parse_scmap_payload, render_map_png
1814
from roborock.protocols.b01_q7_protocol import Q7RequestMessage
1915
from roborock.roborock_typing import RoborockB01Q7Methods
2016

@@ -23,6 +19,8 @@
2319
class B01MapContent(MapContent):
2420
"""B01 map content wrapper."""
2521

22+
rooms: dict[int, str] | None = None
23+
2624

2725
def _extract_current_map_id(map_list_response: dict[str, Any] | None) -> int | None:
2826
if not isinstance(map_list_response, dict):
@@ -77,5 +75,6 @@ async def refresh(self) -> B01MapContent:
7775
parsed = parse_scmap_payload(inflated)
7876
self.raw_api_response = raw_payload
7977
self.map_data = None
78+
self.rooms = parsed.rooms
8079
self.image_content = render_map_png(parsed)
8180
return self

roborock/map/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Module for Roborock map related data classes."""
1+
"""Utilities and data classes for working with Roborock maps."""
22

33
from .b01_map_parser import B01MapData, decode_b01_map_payload, parse_scmap_payload, render_map_png
44
from .map_parser import MapParserConfig, ParsedMapData

roborock/map/b01_map_parser.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class B01MapData:
2424
size_x: int
2525
size_y: int
2626
map_data: bytes
27+
rooms: dict[int, str] | None = None
2728

2829

2930
def _read_varint(buf: bytes, idx: int) -> tuple[int, int]:
@@ -72,12 +73,36 @@ def _parse_map_data_info(blob: bytes) -> bytes:
7273
raise RoborockException("B01 map payload missing mapData")
7374

7475

76+
def _parse_room_data_info(blob: bytes) -> tuple[int | None, str | None]:
77+
room_id: int | None = None
78+
room_name: str | None = None
79+
idx = 0
80+
while idx < len(blob):
81+
key, idx = _read_varint(blob, idx)
82+
field_no = key >> 3
83+
wire = key & 0x07
84+
if wire == 0:
85+
value, idx = _read_varint(blob, idx)
86+
if field_no == 1:
87+
room_id = int(value)
88+
elif wire == 2:
89+
value, idx = _read_len_delimited(blob, idx)
90+
if field_no == 2:
91+
room_name = value.decode("utf-8", errors="replace")
92+
elif wire == 5:
93+
idx += 4
94+
else:
95+
raise RoborockException(f"Unsupported wire type {wire} in B01 room data info")
96+
return room_id, room_name
97+
98+
7599
def parse_scmap_payload(payload: bytes) -> B01MapData:
76-
"""Parse SCMap protobuf payload and extract occupancy grid bytes."""
100+
"""Parse SCMap protobuf payload and extract occupancy grid bytes and room names."""
77101

78102
size_x = 0
79103
size_y = 0
80104
grid = b""
105+
rooms: dict[int, str] = {}
81106
idx = 0
82107
while idx < len(payload):
83108
key, idx = _read_varint(payload, idx)
@@ -112,12 +137,16 @@ def parse_scmap_payload(payload: bytes) -> B01MapData:
112137
raise RoborockException(f"Unsupported wire type {hwire} in B01 map header")
113138
elif field_no == 4: # mapDataInfo
114139
grid = _parse_map_data_info(value)
140+
elif field_no == 12: # roomDataInfo (repeated)
141+
room_id, room_name = _parse_room_data_info(value)
142+
if room_id is not None:
143+
rooms[room_id] = room_name or f"Room {room_id}"
115144

116145
if not size_x or not size_y or not grid:
117146
raise RoborockException("Failed to parse B01 map header/grid")
118147
if len(grid) < size_x * size_y:
119148
raise RoborockException("B01 map data shorter than expected dimensions")
120-
return B01MapData(size_x=size_x, size_y=size_y, map_data=grid)
149+
return B01MapData(size_x=size_x, size_y=size_y, map_data=grid, rooms=rooms or None)
121150

122151

123152
def _derive_b01_iv(iv_seed: int) -> bytes:

tests/devices/traits/b01/q7/test_init.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,9 @@ async def test_q7_map_content_refresh_errors_without_map_list(
378378

379379
with pytest.raises(RoborockException, match="Unable to determine map_id"):
380380
await q7_api.map_content.refresh()
381+
382+
383+
async def test_q7_api_constructor_backwards_compatible_without_map_context(fake_channel: FakeChannel):
384+
"""Direct API construction without map context should still work."""
385+
api = Q7PropertiesApi(fake_channel) # type: ignore[arg-type]
386+
assert api.map_content is None

tests/map/test_b01_map_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def test_parse_scmap_payload_fixture() -> None:
2020
assert parsed.size_x == 340
2121
assert parsed.size_y == 300
2222
assert len(parsed.map_data) >= parsed.size_x * parsed.size_y
23+
assert parsed.rooms is not None
24+
assert parsed.rooms.get(10) == "room1"
2325

2426

2527
def test_render_map_png_fixture() -> None:

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)