|
5 | 5 | import asyncio |
6 | 6 | import json |
7 | 7 | import logging |
| 8 | +import weakref |
8 | 9 | from typing import Any |
9 | 10 |
|
10 | 11 | from roborock.devices.transport.mqtt_channel import MqttChannel |
|
18 | 19 |
|
19 | 20 | _LOGGER = logging.getLogger(__name__) |
20 | 21 | _TIMEOUT = 10.0 |
| 22 | +_map_command_locks: weakref.WeakKeyDictionary[MqttChannel, asyncio.Lock] = weakref.WeakKeyDictionary() |
| 23 | + |
| 24 | + |
| 25 | +def _get_map_command_lock(mqtt_channel: MqttChannel) -> asyncio.Lock: |
| 26 | + lock = _map_command_locks.get(mqtt_channel) |
| 27 | + if lock is None: |
| 28 | + lock = asyncio.Lock() |
| 29 | + _map_command_locks[mqtt_channel] = lock |
| 30 | + return lock |
21 | 31 |
|
22 | 32 |
|
23 | 33 | async def send_decoded_command( |
@@ -102,49 +112,54 @@ def find_response(response_message: RoborockMessage) -> None: |
102 | 112 |
|
103 | 113 |
|
104 | 114 | async def send_map_command(mqtt_channel: MqttChannel, request_message: Q7RequestMessage) -> bytes: |
105 | | - """Send map upload command and wait for MAP_RESPONSE payload bytes.""" |
| 115 | + """Send map upload command and wait for MAP_RESPONSE payload bytes. |
106 | 116 |
|
107 | | - roborock_message = encode_mqtt_payload(request_message) |
108 | | - future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future() |
| 117 | + Map requests are serialized per channel so concurrent map calls cannot |
| 118 | + cross-wire responses between callers. |
| 119 | + """ |
109 | 120 |
|
110 | | - def find_response(response_message: RoborockMessage) -> None: |
111 | | - if future.done(): |
112 | | - return |
| 121 | + async with _get_map_command_lock(mqtt_channel): |
| 122 | + roborock_message = encode_mqtt_payload(request_message) |
| 123 | + future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future() |
113 | 124 |
|
114 | | - if response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE and response_message.payload: |
115 | | - future.set_result(response_message.payload) |
116 | | - return |
| 125 | + def find_response(response_message: RoborockMessage) -> None: |
| 126 | + if future.done(): |
| 127 | + return |
117 | 128 |
|
118 | | - try: |
119 | | - decoded_dps = decode_rpc_response(response_message) |
120 | | - except RoborockException: |
121 | | - return |
| 129 | + if response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE and response_message.payload: |
| 130 | + future.set_result(response_message.payload) |
| 131 | + return |
122 | 132 |
|
123 | | - for dps_value in decoded_dps.values(): |
124 | | - if not isinstance(dps_value, str): |
125 | | - continue |
126 | 133 | try: |
127 | | - inner = json.loads(dps_value) |
128 | | - except (json.JSONDecodeError, TypeError): |
129 | | - continue |
130 | | - if not isinstance(inner, dict) or inner.get("msgId") != str(request_message.msg_id): |
131 | | - continue |
132 | | - code = inner.get("code", 0) |
133 | | - if code != 0: |
134 | | - future.set_exception(RoborockException(f"B01 command failed with code {code} ({request_message})")) |
| 134 | + decoded_dps = decode_rpc_response(response_message) |
| 135 | + except RoborockException: |
135 | 136 | return |
136 | | - data = inner.get("data") |
137 | | - if isinstance(data, dict) and isinstance(data.get("payload"), str): |
| 137 | + |
| 138 | + for dps_value in decoded_dps.values(): |
| 139 | + if not isinstance(dps_value, str): |
| 140 | + continue |
138 | 141 | try: |
139 | | - future.set_result(bytes.fromhex(data["payload"])) |
140 | | - except ValueError: |
141 | | - pass |
| 142 | + inner = json.loads(dps_value) |
| 143 | + except (json.JSONDecodeError, TypeError): |
| 144 | + continue |
| 145 | + if not isinstance(inner, dict) or inner.get("msgId") != str(request_message.msg_id): |
| 146 | + continue |
| 147 | + code = inner.get("code", 0) |
| 148 | + if code != 0: |
| 149 | + future.set_exception(RoborockException(f"B01 command failed with code {code} ({request_message})")) |
| 150 | + return |
| 151 | + data = inner.get("data") |
| 152 | + if isinstance(data, dict) and isinstance(data.get("payload"), str): |
| 153 | + try: |
| 154 | + future.set_result(bytes.fromhex(data["payload"])) |
| 155 | + except ValueError: |
| 156 | + pass |
142 | 157 |
|
143 | | - unsub = await mqtt_channel.subscribe(find_response) |
144 | | - try: |
145 | | - await mqtt_channel.publish(roborock_message) |
146 | | - return await asyncio.wait_for(future, timeout=_TIMEOUT) |
147 | | - except TimeoutError as ex: |
148 | | - raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex |
149 | | - finally: |
150 | | - unsub() |
| 158 | + unsub = await mqtt_channel.subscribe(find_response) |
| 159 | + try: |
| 160 | + await mqtt_channel.publish(roborock_message) |
| 161 | + return await asyncio.wait_for(future, timeout=_TIMEOUT) |
| 162 | + except TimeoutError as ex: |
| 163 | + raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex |
| 164 | + finally: |
| 165 | + unsub() |
0 commit comments