Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from roborock.web_api import RoborockApiClient, UserWebApiClient

from .cache import Cache, DeviceCache, NoCache
from .rpc.b01_q7_channel import create_b01_q7_channel
from .rpc.b01_q10_channel import create_b01_q10_channel
from .rpc.v1_channel import create_v1_channel
from .traits import Trait, a01, b01, v1
from .transport.channel import Channel
Expand Down Expand Up @@ -254,13 +256,22 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
trait = a01.create(product, channel)
case DeviceVersion.B01:
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
model_part = product.model.split(".")[-1]
if "ss" in model_part:
b01_q10_channel = create_b01_q10_channel(mqtt_channel)
channel = b01_q10_channel
trait = b01.q10.create(channel)
elif "sc" in model_part:
# Q7 devices start with 'sc' in their model naming.
trait = b01.q7.create(product, device, channel)
b01_q7_channel = create_b01_q7_channel(device, product, mqtt_channel)
channel = b01_q7_channel
trait = b01.q7.create(
product,
device,
rpc_channel=b01_q7_channel,
map_rpc_channel=b01_q7_channel,
)
else:
raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}")
case _:
Expand Down
54 changes: 53 additions & 1 deletion roborock/devices/rpc/b01_q10_channel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Thin wrapper around the MQTT channel for Roborock B01 Q10 devices."""

from __future__ import annotations

import logging
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Protocol

from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
from roborock.devices.transport.channel import Channel
from roborock.devices.transport.mqtt_channel import MqttChannel
from roborock.exceptions import RoborockException
from roborock.protocols.b01_q10_protocol import (
Expand All @@ -12,10 +16,53 @@
decode_message,
encode_mqtt_payload,
)
from roborock.roborock_message import RoborockMessage

_LOGGER = logging.getLogger(__name__)


class Q10RpcChannel(Protocol):
"""Protocol for Q10 RPC channels."""

async def send_command(
self,
command: B01_Q10_DP,
params: ParamsType = None,
) -> None:
"""Send a command on the MQTT channel, without waiting for a response."""
...


class B01Q10Channel(Channel, Q10RpcChannel):
"""Unified B01 Q10 channel wrapping MQTT transport."""

def __init__(self, mqtt_channel: MqttChannel) -> None:
self._mqtt_channel = mqtt_channel

@property
def is_connected(self) -> bool:
return self._mqtt_channel.is_connected

@property
def is_local_connected(self) -> bool:
return False

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
return await self._mqtt_channel.subscribe(callback)

async def subscribe_stream(self) -> AsyncGenerator[Q10Message, None]:
"""Stream decoded Q10 messages received via MQTT."""
async for msg in stream_decoded_messages(self._mqtt_channel):
yield msg

async def send_command(
self,
command: B01_Q10_DP,
params: ParamsType = None,
) -> None:
await send_command(self._mqtt_channel, command, params)


async def stream_decoded_messages(
mqtt_channel: MqttChannel,
) -> AsyncGenerator[Q10Message, None]:
Expand Down Expand Up @@ -59,3 +106,8 @@ async def send_command(
ex,
)
raise


def create_b01_q10_channel(mqtt_channel: MqttChannel) -> B01Q10Channel:
"""Create a B01Q10Channel wrapping MQTT transport."""
return B01Q10Channel(mqtt_channel)
92 changes: 76 additions & 16 deletions roborock/devices/rpc/b01_q7_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import json
import logging
from collections.abc import Callable
from typing import TypeAlias, TypeVar
from typing import Protocol, TypeAlias, TypeVar

from roborock.data import HomeDataDevice, HomeDataProduct
from roborock.devices.transport.channel import Channel
from roborock.devices.transport.mqtt_channel import MqttChannel
from roborock.exceptions import RoborockException
from roborock.protocols.b01_q7_protocol import (
B01_Q7_DPS,
B01_VERSION,
CommandType,
MapKey,
ParamsType,
Q7RequestMessage,
create_map_key,
decode_map_payload,
decode_rpc_response,
encode_mqtt_payload,
Expand All @@ -26,6 +32,30 @@
DecodedB01Response: TypeAlias = dict[str, object] | str


class Q7RpcChannel(Protocol):
"""Protocol for Q7 RPC channels."""

async def send_command(
self,
command: CommandType,
params: ParamsType = None,
) -> DecodedB01Response:
"""Send a command and get a decoded response."""
...


class Q7MapRpcChannel(Protocol):
"""Protocol for Q7 map RPC channels."""

async def send_map_command(
self,
command: CommandType,
params: ParamsType = None,
) -> bytes:
"""Send a map command and get decoded bytes."""
...


def _matches_map_response(response_message: RoborockMessage, *, version: bytes | None) -> bytes | None:
"""Return raw map payload bytes for matching MAP_RESPONSE messages."""
if (
Expand Down Expand Up @@ -120,39 +150,55 @@ def find_response(response_message: RoborockMessage) -> DecodedB01Response | Non
raise RoborockException(f"B01 command timed out after {_TIMEOUT}s ({request_message})") from ex
except RoborockException as ex:
_LOGGER.warning(
"Error sending B01 decoded command (%ss): %s",
"Error sending B01 decoded command (%s): %s",
request_message,
ex,
)
raise
except Exception as ex:
_LOGGER.exception(
"Error sending B01 decoded command (%ss): %s",
"Error sending B01 decoded command (%s): %s",
request_message,
ex,
)
raise


class MapRpcChannel:
"""RPC channel for map-related commands on B01/Q7 devices."""
class B01Q7Channel(Channel, Q7RpcChannel, Q7MapRpcChannel):
"""Unified B01 Q7 channel wrapping MQTT transport."""

def __init__(self, mqtt_channel: MqttChannel, map_key: MapKey) -> None:
self._mqtt_channel = mqtt_channel
self._map_key = map_key

async def send_map_command(self, request_message: Q7RequestMessage) -> bytes:
"""Send a map upload command and return decoded SCMap bytes.

This publishes the request and waits for a matching ``MAP_RESPONSE`` message
with the correct protocol version. The raw ``MAP_RESPONSE`` payload bytes are
then decoded/inflated via :func:`decode_map_payload` using this channel's
``map_key``, and the resulting SCMap bytes are returned.

The returned value is the decoded map data bytes suitable for passing to the
map parser library, not the raw MQTT ``MAP_RESPONSE`` payload bytes.
"""
@property
def is_connected(self) -> bool:
return self._mqtt_channel.is_connected

@property
def is_local_connected(self) -> bool:
return False

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
return await self._mqtt_channel.subscribe(callback)

async def send_command(
self,
command: CommandType,
params: ParamsType = None,
) -> DecodedB01Response:
return await send_decoded_command(
self._mqtt_channel,
Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params),
)

async def send_map_command(
self,
command: CommandType,
params: ParamsType = None,
) -> bytes:
"""Send a map upload command and return decoded SCMap bytes."""
request_message = Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params)
try:
raw_payload = await _send_command(
self._mqtt_channel,
Expand All @@ -163,3 +209,17 @@ async def send_map_command(self, request_message: Q7RequestMessage) -> bytes:
raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex

return decode_map_payload(raw_payload, map_key=self._map_key)


def create_b01_q7_channel(
device: HomeDataDevice,
product: HomeDataProduct,
mqtt_channel: MqttChannel,
) -> B01Q7Channel:
"""Create a B01Q7Channel for the given device."""
if device.sn is None or product.model is None:
raise RoborockException(
f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})"
)
map_key = create_map_key(serial=device.sn, model=product.model)
return B01Q7Channel(mqtt_channel, map_key)
9 changes: 4 additions & 5 deletions roborock/devices/traits/b01/q10/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import logging

from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
from roborock.devices.rpc.b01_q10_channel import stream_decoded_messages
from roborock.devices.rpc.b01_q10_channel import B01Q10Channel
from roborock.devices.traits import Trait
from roborock.devices.transport.mqtt_channel import MqttChannel
from roborock.map.b01_q10_map_parser import Q10MapPacket, Q10TracePacket
from roborock.protocols.b01_q10_protocol import Q10DpsUpdate, Q10Message

Expand Down Expand Up @@ -78,7 +77,7 @@ class Q10PropertiesApi(Trait):
map: MapContentTrait
"""Trait for fetching the current parsed map (image + rooms)."""

def __init__(self, channel: MqttChannel) -> None:
def __init__(self, channel: B01Q10Channel) -> None:
"""Initialize the B01Props API."""
self._channel = channel
self.command = CommandTrait(channel)
Expand Down Expand Up @@ -127,7 +126,7 @@ async def refresh(self) -> None:

async def _subscribe_loop(self) -> None:
"""Persistent loop dispatching decoded messages to the read-model traits."""
async for message in stream_decoded_messages(self._channel):
async for message in self._channel.subscribe_stream():
self._handle_message(message)

def _handle_message(self, message: Q10Message) -> None:
Expand All @@ -150,6 +149,6 @@ def _handle_message(self, message: Q10Message) -> None:
trait.update_from_dps(message.dps)


def create(channel: MqttChannel) -> Q10PropertiesApi:
def create(channel: B01Q10Channel) -> Q10PropertiesApi:
"""Create traits for B01 devices."""
return Q10PropertiesApi(channel)
11 changes: 5 additions & 6 deletions roborock/devices/traits/b01/q10/command.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any

from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP
from roborock.devices.rpc.b01_q10_channel import send_command
from roborock.devices.transport.mqtt_channel import MqttChannel
from roborock.devices.rpc.b01_q10_channel import Q10RpcChannel
from roborock.protocols.b01_q10_protocol import ParamsType


Expand All @@ -15,9 +14,9 @@ class CommandTrait:
available.
"""

def __init__(self, channel: MqttChannel) -> None:
def __init__(self, rpc_channel: Q10RpcChannel) -> None:
"""Initialize the CommandTrait."""
self._channel = channel
self._rpc_channel = rpc_channel

async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any:
"""Send a command to the device.
Expand All @@ -27,6 +26,6 @@ async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any:
caller to ensure that any traits affected by the command are refreshed
as needed.
"""
if not self._channel:
if not self._rpc_channel:
raise ValueError("Device trait in invalid state")
return await send_command(self._channel, command, params=params)
return await self._rpc_channel.send_command(command, params=params)
39 changes: 24 additions & 15 deletions roborock/devices/traits/b01/q7/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
SCWindMapping,
WaterLevelMapping,
)
from roborock.devices.rpc.b01_q7_channel import MapRpcChannel, send_decoded_command
from roborock.devices.rpc.b01_q7_channel import Q7MapRpcChannel, Q7RpcChannel
from roborock.devices.traits import Trait
from roborock.devices.transport.mqtt_channel import MqttChannel
from roborock.exceptions import RoborockException
from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, CommandType, ParamsType, Q7RequestMessage, create_map_key
from roborock.protocols.b01_q7_protocol import CommandType, ParamsType
from roborock.roborock_message import RoborockB01Props
from roborock.roborock_typing import RoborockB01Q7Methods

Expand Down Expand Up @@ -53,19 +52,23 @@ class Q7PropertiesApi(Trait):
"""Trait for fetching parsed current map content."""

def __init__(
self, channel: MqttChannel, map_rpc_channel: MapRpcChannel, device: HomeDataDevice, product: HomeDataProduct
self,
rpc_channel: Q7RpcChannel,
map_rpc_channel: Q7MapRpcChannel,
device: HomeDataDevice,
product: HomeDataProduct,
) -> None:
"""Initialize the Q7 API."""
self._channel = channel
self._rpc_channel = rpc_channel
self._map_rpc_channel = map_rpc_channel
self._device = device
self._product = product

if not device.sn or not product.model:
raise ValueError("B01 Q7 map content requires device serial number and product model metadata")

self.clean_summary = CleanSummaryTrait(channel)
self.map = MapTrait(channel)
self.clean_summary = CleanSummaryTrait(rpc_channel)
self.map = MapTrait(rpc_channel)
self.map_content = MapContentTrait(
self._map_rpc_channel,
self.map,
Expand Down Expand Up @@ -199,17 +202,23 @@ async def find_me(self) -> None:

async def send(self, command: CommandType, params: ParamsType) -> Any:
"""Send a command to the device."""
return await send_decoded_command(
self._channel,
Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params),
)
return await self._rpc_channel.send_command(command, params)


def create(product: HomeDataProduct, device: HomeDataDevice, channel: MqttChannel) -> Q7PropertiesApi:
def create(
product: HomeDataProduct,
device: HomeDataDevice,
rpc_channel: Q7RpcChannel,
map_rpc_channel: Q7MapRpcChannel,
) -> Q7PropertiesApi:
"""Create traits for B01 Q7 devices."""
if device.sn is None or product.model is None:
raise RoborockException(
f"Device serial number and product model are required (sn:: {device.sn}, model: {product.model})"
f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})"
)
map_rpc_channel = MapRpcChannel(channel, map_key=create_map_key(serial=device.sn, model=product.model))
return Q7PropertiesApi(channel, device=device, product=product, map_rpc_channel=map_rpc_channel)
return Q7PropertiesApi(
rpc_channel=rpc_channel,
map_rpc_channel=map_rpc_channel,
device=device,
product=product,
)
Loading
Loading