Skip to content

Commit 12a91bb

Browse files
committed
refactor: decouple B01 (Q7/Q10) protocol layer from transport layer
1 parent 5712f92 commit 12a91bb

23 files changed

Lines changed: 762 additions & 528 deletions

roborock/devices/device_manager.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from roborock.web_api import RoborockApiClient, UserWebApiClient
2727

2828
from .cache import Cache, DeviceCache, NoCache
29+
from .rpc.b01_q7_channel import create_b01_q7_channel
30+
from .rpc.b01_q10_channel import create_b01_q10_channel
2931
from .rpc.v1_channel import create_v1_channel
3032
from .traits import Trait, a01, b01, v1
3133
from .transport.channel import Channel
@@ -254,13 +256,22 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
254256
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
255257
trait = a01.create(product, channel)
256258
case DeviceVersion.B01:
257-
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
259+
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
258260
model_part = product.model.split(".")[-1]
259261
if "ss" in model_part:
262+
b01_q10_channel = create_b01_q10_channel(mqtt_channel)
263+
channel = b01_q10_channel
260264
trait = b01.q10.create(channel)
261265
elif "sc" in model_part:
262266
# Q7 devices start with 'sc' in their model naming.
263-
trait = b01.q7.create(product, device, channel)
267+
b01_q7_channel = create_b01_q7_channel(device, product, mqtt_channel)
268+
channel = b01_q7_channel
269+
trait = b01.q7.create(
270+
product,
271+
device,
272+
rpc_channel=b01_q7_channel,
273+
map_rpc_channel=b01_q7_channel,
274+
)
264275
else:
265276
raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}")
266277
case _:

roborock/devices/rpc/b01_q10_channel.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Thin wrapper around the MQTT channel for Roborock B01 Q10 devices."""
22

3+
from __future__ import annotations
4+
35
import logging
4-
from collections.abc import AsyncGenerator
6+
from collections.abc import AsyncGenerator, Callable
7+
from typing import Protocol
58

69
from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
10+
from roborock.devices.transport.channel import Channel
711
from roborock.devices.transport.mqtt_channel import MqttChannel
812
from roborock.exceptions import RoborockException
913
from roborock.protocols.b01_q10_protocol import (
@@ -12,10 +16,53 @@
1216
decode_message,
1317
encode_mqtt_payload,
1418
)
19+
from roborock.roborock_message import RoborockMessage
1520

1621
_LOGGER = logging.getLogger(__name__)
1722

1823

24+
class Q10RpcChannel(Protocol):
25+
"""Protocol for Q10 RPC channels."""
26+
27+
async def send_command(
28+
self,
29+
command: B01_Q10_DP,
30+
params: ParamsType = None,
31+
) -> None:
32+
"""Send a command on the MQTT channel, without waiting for a response."""
33+
...
34+
35+
36+
class B01Q10Channel(Channel, Q10RpcChannel):
37+
"""Unified B01 Q10 channel wrapping MQTT transport."""
38+
39+
def __init__(self, mqtt_channel: MqttChannel) -> None:
40+
self._mqtt_channel = mqtt_channel
41+
42+
@property
43+
def is_connected(self) -> bool:
44+
return self._mqtt_channel.is_connected
45+
46+
@property
47+
def is_local_connected(self) -> bool:
48+
return False
49+
50+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
51+
return await self._mqtt_channel.subscribe(callback)
52+
53+
async def subscribe_stream(self) -> AsyncGenerator[Q10Message, None]:
54+
"""Stream decoded Q10 messages received via MQTT."""
55+
async for msg in stream_decoded_messages(self._mqtt_channel):
56+
yield msg
57+
58+
async def send_command(
59+
self,
60+
command: B01_Q10_DP,
61+
params: ParamsType = None,
62+
) -> None:
63+
await send_command(self._mqtt_channel, command, params)
64+
65+
1966
async def stream_decoded_messages(
2067
mqtt_channel: MqttChannel,
2168
) -> AsyncGenerator[Q10Message, None]:
@@ -59,3 +106,8 @@ async def send_command(
59106
ex,
60107
)
61108
raise
109+
110+
111+
def create_b01_q10_channel(mqtt_channel: MqttChannel) -> B01Q10Channel:
112+
"""Create a B01Q10Channel wrapping MQTT transport."""
113+
return B01Q10Channel(mqtt_channel)

roborock/devices/rpc/b01_q7_channel.py

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
import json
77
import logging
88
from collections.abc import Callable
9-
from typing import TypeAlias, TypeVar
9+
from typing import Protocol, TypeAlias, TypeVar
1010

11+
from roborock.data import HomeDataDevice, HomeDataProduct
12+
from roborock.devices.transport.channel import Channel
1113
from roborock.devices.transport.mqtt_channel import MqttChannel
1214
from roborock.exceptions import RoborockException
1315
from roborock.protocols.b01_q7_protocol import (
16+
B01_Q7_DPS,
1417
B01_VERSION,
18+
CommandType,
1519
MapKey,
20+
ParamsType,
1621
Q7RequestMessage,
22+
create_map_key,
1723
decode_map_payload,
1824
decode_rpc_response,
1925
encode_mqtt_payload,
@@ -26,6 +32,30 @@
2632
DecodedB01Response: TypeAlias = dict[str, object] | str
2733

2834

35+
class Q7RpcChannel(Protocol):
36+
"""Protocol for Q7 RPC channels."""
37+
38+
async def send_command(
39+
self,
40+
command: CommandType,
41+
params: ParamsType = None,
42+
) -> DecodedB01Response:
43+
"""Send a command and get a decoded response."""
44+
...
45+
46+
47+
class Q7MapRpcChannel(Protocol):
48+
"""Protocol for Q7 map RPC channels."""
49+
50+
async def send_map_command(
51+
self,
52+
command: CommandType,
53+
params: ParamsType = None,
54+
) -> bytes:
55+
"""Send a map command and get decoded bytes."""
56+
...
57+
58+
2959
def _matches_map_response(response_message: RoborockMessage, *, version: bytes | None) -> bytes | None:
3060
"""Return raw map payload bytes for matching MAP_RESPONSE messages."""
3161
if (
@@ -120,39 +150,55 @@ def find_response(response_message: RoborockMessage) -> DecodedB01Response | Non
120150
raise RoborockException(f"B01 command timed out after {_TIMEOUT}s ({request_message})") from ex
121151
except RoborockException as ex:
122152
_LOGGER.warning(
123-
"Error sending B01 decoded command (%ss): %s",
153+
"Error sending B01 decoded command (%s): %s",
124154
request_message,
125155
ex,
126156
)
127157
raise
128158
except Exception as ex:
129159
_LOGGER.exception(
130-
"Error sending B01 decoded command (%ss): %s",
160+
"Error sending B01 decoded command (%s): %s",
131161
request_message,
132162
ex,
133163
)
134164
raise
135165

136166

137-
class MapRpcChannel:
138-
"""RPC channel for map-related commands on B01/Q7 devices."""
167+
class B01Q7Channel(Channel, Q7RpcChannel, Q7MapRpcChannel):
168+
"""Unified B01 Q7 channel wrapping MQTT transport."""
139169

140170
def __init__(self, mqtt_channel: MqttChannel, map_key: MapKey) -> None:
141171
self._mqtt_channel = mqtt_channel
142172
self._map_key = map_key
143173

144-
async def send_map_command(self, request_message: Q7RequestMessage) -> bytes:
145-
"""Send a map upload command and return decoded SCMap bytes.
146-
147-
This publishes the request and waits for a matching ``MAP_RESPONSE`` message
148-
with the correct protocol version. The raw ``MAP_RESPONSE`` payload bytes are
149-
then decoded/inflated via :func:`decode_map_payload` using this channel's
150-
``map_key``, and the resulting SCMap bytes are returned.
151-
152-
The returned value is the decoded map data bytes suitable for passing to the
153-
map parser library, not the raw MQTT ``MAP_RESPONSE`` payload bytes.
154-
"""
174+
@property
175+
def is_connected(self) -> bool:
176+
return self._mqtt_channel.is_connected
177+
178+
@property
179+
def is_local_connected(self) -> bool:
180+
return False
181+
182+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
183+
return await self._mqtt_channel.subscribe(callback)
184+
185+
async def send_command(
186+
self,
187+
command: CommandType,
188+
params: ParamsType = None,
189+
) -> DecodedB01Response:
190+
return await send_decoded_command(
191+
self._mqtt_channel,
192+
Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params),
193+
)
155194

195+
async def send_map_command(
196+
self,
197+
command: CommandType,
198+
params: ParamsType = None,
199+
) -> bytes:
200+
"""Send a map upload command and return decoded SCMap bytes."""
201+
request_message = Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params)
156202
try:
157203
raw_payload = await _send_command(
158204
self._mqtt_channel,
@@ -163,3 +209,17 @@ async def send_map_command(self, request_message: Q7RequestMessage) -> bytes:
163209
raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex
164210

165211
return decode_map_payload(raw_payload, map_key=self._map_key)
212+
213+
214+
def create_b01_q7_channel(
215+
device: HomeDataDevice,
216+
product: HomeDataProduct,
217+
mqtt_channel: MqttChannel,
218+
) -> B01Q7Channel:
219+
"""Create a B01Q7Channel for the given device."""
220+
if device.sn is None or product.model is None:
221+
raise RoborockException(
222+
f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})"
223+
)
224+
map_key = create_map_key(serial=device.sn, model=product.model)
225+
return B01Q7Channel(mqtt_channel, map_key)

roborock/devices/traits/b01/q10/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import logging
55

66
from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
7-
from roborock.devices.rpc.b01_q10_channel import stream_decoded_messages
7+
from roborock.devices.rpc.b01_q10_channel import B01Q10Channel
88
from roborock.devices.traits import Trait
9-
from roborock.devices.transport.mqtt_channel import MqttChannel
109
from roborock.map.b01_q10_map_parser import Q10MapPacket, Q10TracePacket
1110
from roborock.protocols.b01_q10_protocol import Q10DpsUpdate, Q10Message
1211

@@ -78,7 +77,7 @@ class Q10PropertiesApi(Trait):
7877
map: MapContentTrait
7978
"""Trait for fetching the current parsed map (image + rooms)."""
8079

81-
def __init__(self, channel: MqttChannel) -> None:
80+
def __init__(self, channel: B01Q10Channel) -> None:
8281
"""Initialize the B01Props API."""
8382
self._channel = channel
8483
self.command = CommandTrait(channel)
@@ -127,7 +126,7 @@ async def refresh(self) -> None:
127126

128127
async def _subscribe_loop(self) -> None:
129128
"""Persistent loop dispatching decoded messages to the read-model traits."""
130-
async for message in stream_decoded_messages(self._channel):
129+
async for message in self._channel.subscribe_stream():
131130
self._handle_message(message)
132131

133132
def _handle_message(self, message: Q10Message) -> None:
@@ -150,6 +149,6 @@ def _handle_message(self, message: Q10Message) -> None:
150149
trait.update_from_dps(message.dps)
151150

152151

153-
def create(channel: MqttChannel) -> Q10PropertiesApi:
152+
def create(channel: B01Q10Channel) -> Q10PropertiesApi:
154153
"""Create traits for B01 devices."""
155154
return Q10PropertiesApi(channel)

roborock/devices/traits/b01/q10/command.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Any
22

33
from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
4-
from roborock.devices.rpc.b01_q10_channel import send_command
5-
from roborock.devices.transport.mqtt_channel import MqttChannel
4+
from roborock.devices.rpc.b01_q10_channel import Q10RpcChannel
65
from roborock.protocols.b01_q10_protocol import ParamsType
76

87

@@ -15,9 +14,9 @@ class CommandTrait:
1514
available.
1615
"""
1716

18-
def __init__(self, channel: MqttChannel) -> None:
17+
def __init__(self, rpc_channel: Q10RpcChannel) -> None:
1918
"""Initialize the CommandTrait."""
20-
self._channel = channel
19+
self._rpc_channel = rpc_channel
2120

2221
async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any:
2322
"""Send a command to the device.
@@ -27,6 +26,6 @@ async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any:
2726
caller to ensure that any traits affected by the command are refreshed
2827
as needed.
2928
"""
30-
if not self._channel:
29+
if not self._rpc_channel:
3130
raise ValueError("Device trait in invalid state")
32-
return await send_command(self._channel, command, params=params)
31+
return await self._rpc_channel.send_command(command, params=params)

roborock/devices/traits/b01/q7/__init__.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
SCWindMapping,
1919
WaterLevelMapping,
2020
)
21-
from roborock.devices.rpc.b01_q7_channel import MapRpcChannel, send_decoded_command
21+
from roborock.devices.rpc.b01_q7_channel import Q7MapRpcChannel, Q7RpcChannel
2222
from roborock.devices.traits import Trait
23-
from roborock.devices.transport.mqtt_channel import MqttChannel
2423
from roborock.exceptions import RoborockException
25-
from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, CommandType, ParamsType, Q7RequestMessage, create_map_key
24+
from roborock.protocols.b01_q7_protocol import CommandType, ParamsType
2625
from roborock.roborock_message import RoborockB01Props
2726
from roborock.roborock_typing import RoborockB01Q7Methods
2827

@@ -53,19 +52,23 @@ class Q7PropertiesApi(Trait):
5352
"""Trait for fetching parsed current map content."""
5453

5554
def __init__(
56-
self, channel: MqttChannel, map_rpc_channel: MapRpcChannel, device: HomeDataDevice, product: HomeDataProduct
55+
self,
56+
rpc_channel: Q7RpcChannel,
57+
map_rpc_channel: Q7MapRpcChannel,
58+
device: HomeDataDevice,
59+
product: HomeDataProduct,
5760
) -> None:
5861
"""Initialize the Q7 API."""
59-
self._channel = channel
62+
self._rpc_channel = rpc_channel
6063
self._map_rpc_channel = map_rpc_channel
6164
self._device = device
6265
self._product = product
6366

6467
if not device.sn or not product.model:
6568
raise ValueError("B01 Q7 map content requires device serial number and product model metadata")
6669

67-
self.clean_summary = CleanSummaryTrait(channel)
68-
self.map = MapTrait(channel)
70+
self.clean_summary = CleanSummaryTrait(rpc_channel)
71+
self.map = MapTrait(rpc_channel)
6972
self.map_content = MapContentTrait(
7073
self._map_rpc_channel,
7174
self.map,
@@ -199,17 +202,23 @@ async def find_me(self) -> None:
199202

200203
async def send(self, command: CommandType, params: ParamsType) -> Any:
201204
"""Send a command to the device."""
202-
return await send_decoded_command(
203-
self._channel,
204-
Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params),
205-
)
205+
return await self._rpc_channel.send_command(command, params)
206206

207207

208-
def create(product: HomeDataProduct, device: HomeDataDevice, channel: MqttChannel) -> Q7PropertiesApi:
208+
def create(
209+
product: HomeDataProduct,
210+
device: HomeDataDevice,
211+
rpc_channel: Q7RpcChannel,
212+
map_rpc_channel: Q7MapRpcChannel,
213+
) -> Q7PropertiesApi:
209214
"""Create traits for B01 Q7 devices."""
210215
if device.sn is None or product.model is None:
211216
raise RoborockException(
212-
f"Device serial number and product model are required (sn:: {device.sn}, model: {product.model})"
217+
f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})"
213218
)
214-
map_rpc_channel = MapRpcChannel(channel, map_key=create_map_key(serial=device.sn, model=product.model))
215-
return Q7PropertiesApi(channel, device=device, product=product, map_rpc_channel=map_rpc_channel)
219+
return Q7PropertiesApi(
220+
rpc_channel=rpc_channel,
221+
map_rpc_channel=map_rpc_channel,
222+
device=device,
223+
product=product,
224+
)

0 commit comments

Comments
 (0)