diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index cf7d2e0797..87199cde71 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -141,11 +141,14 @@ async def exec( ) -> dict[str, Any]: def _run() -> dict[str, Any]: try: + # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit + # Executes the current interpreter with a fixed argv list and shell=False. result = subprocess.run( - [os.environ.get("PYTHON", sys.executable), "-c", code], + [sys.executable, "-c", code], timeout=timeout, capture_output=True, text=True, + shell=False, ) stdout = "" if silent else result.stdout stderr = result.stderr if result.returncode != 0 else "" diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index b326ebb449..bbf65da7fa 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -164,7 +164,7 @@ def insert_llm_metrics(self, metrics: dict) -> None: def get_base_stats(self, offset_sec: int = 86400) -> Stats: """获取 offset_sec 秒前到现在的基础统计数据""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + min_timestamp = int(time.time()) - offset_sec try: c = self.conn.cursor() @@ -174,8 +174,9 @@ def get_base_stats(self, offset_sec: int = 86400) -> Stats: c.execute( """ SELECT * FROM platform - """ - + where_clause, + WHERE timestamp >= :min_timestamp + """, + {"min_timestamp": min_timestamp}, ) platform = [] @@ -203,7 +204,7 @@ def get_total_message_count(self) -> int: def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: """获取 offset_sec 秒前到现在的基础统计数据(合并)""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + min_timestamp = int(time.time()) - offset_sec try: c = self.conn.cursor() @@ -213,9 +214,10 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: c.execute( """ SELECT name, SUM(count), timestamp FROM platform - """ - + where_clause - + " GROUP BY name", + WHERE timestamp >= :min_timestamp + GROUP BY name + """, + {"min_timestamp": min_timestamp}, ) platform = [] @@ -403,14 +405,15 @@ def get_filtered_conversations( try: # 构建查询条件 where_clauses = [] - params = [] + params: dict[str, Any] = {} # 平台筛选 if platforms and len(platforms) > 0: platform_conditions = [] - for platform in platforms: - platform_conditions.append("user_id LIKE ?") - params.append(f"{platform}:%") + for index, platform in enumerate(platforms): + param_name = f"platform_{index}" + platform_conditions.append(f"user_id LIKE :{param_name}") + params[param_name] = f"{platform}:%" if platform_conditions: where_clauses.append(f"({' OR '.join(platform_conditions)})") @@ -418,9 +421,10 @@ def get_filtered_conversations( # 消息类型筛选 if message_types and len(message_types) > 0: message_type_conditions = [] - for msg_type in message_types: - message_type_conditions.append("user_id LIKE ?") - params.append(f"%:{msg_type}:%") + for index, msg_type in enumerate(message_types): + param_name = f"message_type_{index}" + message_type_conditions.append(f"user_id LIKE :{param_name}") + params[param_name] = f"%:{msg_type}:%" if message_type_conditions: where_clauses.append(f"({' OR '.join(message_type_conditions)})") @@ -429,28 +433,32 @@ def get_filtered_conversations( if search_query: search_query = search_query.encode("unicode_escape").decode("utf-8") where_clauses.append( - "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)", + "(" + "title LIKE :search_query OR user_id LIKE :search_query OR " + "cid LIKE :search_query OR history LIKE :search_query" + ")", ) - search_param = f"%{search_query}%" - params.extend([search_param, search_param, search_param, search_param]) + params["search_query"] = f"%{search_query}%" # 排除特定用户ID if exclude_ids and len(exclude_ids) > 0: - for exclude_id in exclude_ids: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_id}%") + for index, exclude_id in enumerate(exclude_ids): + param_name = f"exclude_id_{index}" + where_clauses.append(f"user_id NOT LIKE :{param_name}") + params[param_name] = f"{exclude_id}%" # 排除特定平台 if exclude_platforms and len(exclude_platforms) > 0: - for exclude_platform in exclude_platforms: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_platform}:%") + for index, exclude_platform in enumerate(exclude_platforms): + param_name = f"exclude_platform_{index}" + where_clauses.append(f"user_id NOT LIKE :{param_name}") + params[param_name] = f"{exclude_platform}:%" # 构建完整的 WHERE 子句 where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" # 构建计数查询 - count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}" + count_sql = "SELECT COUNT(*) FROM webchat_conversation" + where_sql # 获取总记录数 c.execute(count_sql, params) @@ -460,14 +468,14 @@ def get_filtered_conversations( offset = (page - 1) * page_size # 构建分页数据查询 - data_sql = f""" - SELECT user_id, cid, created_at, updated_at, title, persona_id - FROM webchat_conversation - {where_sql} - ORDER BY updated_at DESC - LIMIT ? OFFSET ? - """ - query_params = params + [page_size, offset] + data_sql = ( + "SELECT user_id, cid, created_at, updated_at, title, persona_id\n" + "FROM webchat_conversation" + f"{where_sql}\n" + "ORDER BY updated_at DESC\n" + "LIMIT :page_size OFFSET :offset" + ) + query_params = {**params, "page_size": page_size, "offset": offset} # 获取分页数据 c.execute(data_sql, query_params) diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 3e5eb9a90e..d43eae16a5 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -15,6 +15,7 @@ from astrbot.api import logger +from ..websocket_security import to_websocket_url from .misskey_utils import FileIDExtractor # Constants @@ -56,10 +57,7 @@ def __init__(self, instance_url: str, access_token: str) -> None: async def connect(self) -> bool: try: - ws_url = self.instance_url.replace("https://", "wss://").replace( - "http://", - "ws://", - ) + ws_url = to_websocket_url(self.instance_url, label="Misskey instance URL") ws_url += f"/streaming?i={self.access_token}" self.websocket = await websockets.connect( diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 5c2f7a37f3..6458a53658 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -27,6 +27,8 @@ ) from astrbot.core.platform.astr_message_event import MessageSession +from ..websocket_security import require_secure_transport_url + @register_platform_adapter( "satori", "Satori 协议适配器", support_streaming_message=False @@ -137,9 +139,11 @@ async def connect_websocket(self) -> None: logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}") logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}") - if not self.endpoint.startswith(("ws://", "wss://")): - logger.error(f"无效的WebSocket URL: {self.endpoint}") - raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}") + require_secure_transport_url( + self.endpoint, + label="Satori WebSocket URL", + allowed_schemes={"ws", "wss"}, + ) try: websocket = await connect( diff --git a/astrbot/core/platform/sources/websocket_security.py b/astrbot/core/platform/sources/websocket_security.py new file mode 100644 index 0000000000..47d5904830 --- /dev/null +++ b/astrbot/core/platform/sources/websocket_security.py @@ -0,0 +1,77 @@ +import ipaddress +from urllib.parse import SplitResult, urlsplit, urlunsplit + +_ALLOWED_INSECURE_SUFFIXES = (".local", ".internal") + + +def _is_local_or_private_host(hostname: str | None) -> bool: + if not hostname: + return False + + normalized = hostname.strip("[]").lower() + if normalized == "localhost": + return True + if normalized.endswith(_ALLOWED_INSECURE_SUFFIXES): + return True + + try: + address = ipaddress.ip_address(normalized) + except ValueError: + return False + + return address.is_loopback or address.is_private or address.is_link_local + + +def require_secure_transport_url( + url: str, + *, + label: str, + allowed_schemes: set[str], +) -> SplitResult: + parsed = urlsplit(url) + if parsed.scheme not in allowed_schemes: + allowed = ", ".join(sorted(allowed_schemes)) + raise ValueError(f"{label} must use one of: {allowed}") + + if parsed.scheme in {"http", "ws"} and not _is_local_or_private_host( + parsed.hostname + ): + raise ValueError( + f"{label} must use secure transport (https or wss) for non-local endpoints: {url}", + ) + + return parsed + + +def to_websocket_url(url: str, *, label: str = "WebSocket URL") -> str: + normalized_url = url.rstrip("/") + parsed = urlsplit(normalized_url) + allowed_schemes = {"http", "https", "ws", "wss"} + + if parsed.scheme not in allowed_schemes: + raise ValueError( + f"{label} must use the http, https, ws, or wss scheme: {normalized_url}", + ) + + parsed = require_secure_transport_url( + normalized_url, + label=label, + allowed_schemes=allowed_schemes, + ) + scheme_map = { + "http": "ws", + "https": "wss", + "ws": "ws", + "wss": "wss", + } + + try: + ws_scheme = scheme_map[parsed.scheme] + except KeyError as exc: + raise ValueError( + f"{label} must use the http, https, ws, or wss scheme: {normalized_url}", + ) from exc + + return urlunsplit( + parsed._replace(scheme=ws_scheme), + ) diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 53d9441fab..c6c54db17c 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -129,7 +129,6 @@ async def render( if not template_name: template_name = "base" tmpl_str = await self.get_template(name=template_name) - text = text.replace("`", "\\`") return await self.render_custom_template( tmpl_str, {"text": text, "version": f"v{VERSION}"}, diff --git a/astrbot/core/utils/t2i/template/astrbot_powershell.html b/astrbot/core/utils/t2i/template/astrbot_powershell.html index 9ed3e77a55..93237023b6 100644 --- a/astrbot/core/utils/t2i/template/astrbot_powershell.html +++ b/astrbot/core/utils/t2i/template/astrbot_powershell.html @@ -177,8 +177,8 @@ - \ No newline at end of file + diff --git a/astrbot/core/utils/t2i/template/base.html b/astrbot/core/utils/t2i/template/base.html index 257cff3ff4..b38619e2f9 100644 --- a/astrbot/core/utils/t2i/template/base.html +++ b/astrbot/core/utils/t2i/template/base.html @@ -18,7 +18,7 @@
@@ -244,4 +244,4 @@ } } - \ No newline at end of file + diff --git a/dashboard/src/components/chat/message_list_comps/IPythonToolBlock.vue b/dashboard/src/components/chat/message_list_comps/IPythonToolBlock.vue index 5ff9469a0e..252812394d 100644 --- a/dashboard/src/components/chat/message_list_comps/IPythonToolBlock.vue +++ b/dashboard/src/components/chat/message_list_comps/IPythonToolBlock.vue @@ -25,6 +25,7 @@ import { ref, computed, onMounted } from 'vue'; import { useModuleI18n } from '@/i18n/composables'; import { createHighlighter } from 'shiki'; +import DOMPurify from 'dompurify'; const props = defineProps({ toolCall: { @@ -67,6 +68,13 @@ const code = computed(() => { const result = computed(() => props.toolCall.result); +const escapeHtml = (value) => value + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); + const formattedResult = computed(() => { if (!result.value) return ''; try { @@ -82,13 +90,13 @@ const highlightedCode = computed(() => { return ''; } try { - return shikiHighlighter.value.codeToHtml(code.value, { + return DOMPurify.sanitize(shikiHighlighter.value.codeToHtml(code.value, { lang: 'python', theme: props.isDark ? 'min-dark' : 'github-light' - }); + })); } catch (err) { console.error('Failed to highlight code:', err); - return `
${code.value}
`; + return `
${escapeHtml(code.value)}
`; } }); diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index 9c86c419a6..4530479219 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -1,5 +1,6 @@ \ No newline at end of file + diff --git a/docs/scripts/upload_doc_images_to_r2.py b/docs/scripts/upload_doc_images_to_r2.py index 7db614dc47..23df4d485b 100755 --- a/docs/scripts/upload_doc_images_to_r2.py +++ b/docs/scripts/upload_doc_images_to_r2.py @@ -220,7 +220,9 @@ def run_rclone_upload( else: print(f"Uploading to: {target}") - subprocess.run(cmd, check=True) + # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit + # Uses an argv list with shell=False after checking that rclone exists in PATH. + subprocess.run(cmd, check=True, shell=False) finally: tmp_path.unlink(missing_ok=True) diff --git a/docs/tests/test_upload_doc_images_to_r2.py b/docs/tests/test_upload_doc_images_to_r2.py new file mode 100644 index 0000000000..3465d7238d --- /dev/null +++ b/docs/tests/test_upload_doc_images_to_r2.py @@ -0,0 +1,59 @@ +import sys +import unittest +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + + +def load_upload_module(): + script_path = ( + Path(__file__).resolve().parents[1] / "scripts" / "upload_doc_images_to_r2.py" + ) + spec = spec_from_file_location("upload_doc_images_to_r2", script_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module from {script_path}") + module = module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +class UploadDocImagesToR2Test(unittest.TestCase): + def test_run_rclone_upload_uses_argument_list_without_shell(self): + module = load_upload_module() + + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + + with ( + patch.object(module.shutil, "which", return_value="/usr/bin/rclone"), + patch.object(module.subprocess, "run") as mock_run, + ): + module.run_rclone_upload( + root, + "r2:docs-bucket/assets", + ["guide/image.png"], + dry_run=False, + ) + + args, kwargs = mock_run.call_args + files_from_path = args[0][5] + self.assertEqual( + args[0], + [ + "rclone", + "copy", + str(root), + "r2:docs-bucket/assets", + "--files-from", + files_from_path, + "--create-empty-src-dirs", + ], + ) + self.assertTrue(kwargs["check"]) + self.assertIs(kwargs.get("shell"), False) + + +if __name__ == "__main__": + unittest.main() diff --git a/k8s/astrbot/02-deployment.yaml b/k8s/astrbot/02-deployment.yaml index d2799ab900..b7151a8a6e 100644 --- a/k8s/astrbot/02-deployment.yaml +++ b/k8s/astrbot/02-deployment.yaml @@ -17,10 +17,23 @@ spec: labels: app: astrbot-standalone spec: + securityContext: + # Keep these IDs in sync with the `astrbot` user declared in Dockerfile. + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: astrbot image: soulter/astrbot:latest imagePullPolicy: IfNotPresent + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL env: - name: TZ value: "Asia/Shanghai" @@ -46,4 +59,4 @@ spec: - name: localtime hostPath: path: /etc/localtime - type: File \ No newline at end of file + type: File diff --git a/k8s/astrbot_with_napcat/02-deployment.yaml b/k8s/astrbot_with_napcat/02-deployment.yaml index 510e1520c4..3072c9a5f4 100644 --- a/k8s/astrbot_with_napcat/02-deployment.yaml +++ b/k8s/astrbot_with_napcat/02-deployment.yaml @@ -22,11 +22,24 @@ spec: subdomain: astrbot-stack # 优雅关闭时间,给 NapCat 足够时间保存状态 terminationGracePeriodSeconds: 60 + securityContext: + # Keep these IDs in sync with the `astrbot` user declared in Dockerfile. + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault # 初始化容器:首次生成随机 machine-id,后续复用 initContainers: - name: init-machine-id image: busybox:1.37.0 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL command: - /bin/sh - -c @@ -47,6 +60,11 @@ spec: - name: napcat image: mlikiowa/napcat-docker:latest imagePullPolicy: IfNotPresent + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL env: - name: NAPCAT_UID value: "1000" @@ -86,6 +104,11 @@ spec: - name: astrbot image: soulter/astrbot:latest imagePullPolicy: IfNotPresent + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL env: - name: TZ value: "Asia/Shanghai" diff --git a/tests/unit/test_computer.py b/tests/unit/test_computer.py index 07a5449c19..36dfec48c4 100644 --- a/tests/unit/test_computer.py +++ b/tests/unit/test_computer.py @@ -243,6 +243,28 @@ async def test_exec_with_env(self): class TestLocalPythonComponent: """Tests for LocalPythonComponent.""" + @pytest.mark.asyncio + async def test_exec_uses_fixed_python_executable(self): + """Test Python execution ignores dynamic executable overrides.""" + python = LocalPythonComponent() + + with ( + patch.dict("os.environ", {"PYTHON": "malicious-python"}), + patch("astrbot.core.computer.booters.local.subprocess.run") as mock_run, + ): + mock_run.return_value = MagicMock( + returncode=0, + stdout="hello\n", + stderr="", + ) + + result = await python.exec("print('hello')") + + args, kwargs = mock_run.call_args + assert args[0] == [sys.executable, "-c", "print('hello')"] + assert kwargs.get("shell") is False + assert result["data"]["output"]["text"] == "hello\n" + @pytest.mark.asyncio async def test_exec_simple_code(self): """Test executing simple Python code.""" diff --git a/tests/unit/test_sqlite_v3.py b/tests/unit/test_sqlite_v3.py new file mode 100644 index 0000000000..203faaacd9 --- /dev/null +++ b/tests/unit/test_sqlite_v3.py @@ -0,0 +1,79 @@ +from unittest.mock import MagicMock, patch + +from astrbot.core.db.migration.sqlite_v3 import SQLiteDatabase + + +class TestSQLiteV3QueryParameterization: + def _create_db_with_mock_cursor(self): + db = SQLiteDatabase(":memory:") + cursor = MagicMock() + cursor.fetchall.return_value = [] + cursor.fetchone.return_value = (0,) + + conn = MagicMock() + conn.cursor.return_value = cursor + db.conn = conn + return db, cursor + + def test_get_base_stats_uses_bound_parameter(self): + db, cursor = self._create_db_with_mock_cursor() + + with patch("astrbot.core.db.migration.sqlite_v3.time.time", return_value=1000): + db.get_base_stats(offset_sec=60) + + sql, params = cursor.execute.call_args.args + assert ( + " ".join(sql.split()) + == "SELECT * FROM platform WHERE timestamp >= :min_timestamp" + ) + assert params == {"min_timestamp": 940} + + def test_get_grouped_base_stats_uses_bound_parameter(self): + db, cursor = self._create_db_with_mock_cursor() + + with patch("astrbot.core.db.migration.sqlite_v3.time.time", return_value=1000): + db.get_grouped_base_stats(offset_sec=60) + + sql, params = cursor.execute.call_args.args + assert " ".join(sql.split()) == ( + "SELECT name, SUM(count), timestamp FROM platform " + "WHERE timestamp >= :min_timestamp GROUP BY name" + ) + assert params == {"min_timestamp": 940} + + def test_get_filtered_conversations_uses_named_parameters(self): + db, cursor = self._create_db_with_mock_cursor() + unsafe_search = "x' OR 1=1 --" + + db.get_filtered_conversations( + page=2, + page_size=10, + platforms=["qq"], + message_types=["group"], + search_query=unsafe_search, + exclude_ids=["admin"], + exclude_platforms=["slack"], + ) + + count_sql, count_params = cursor.execute.call_args_list[0].args + data_sql, data_params = cursor.execute.call_args_list[1].args + + assert ":platform_0" in count_sql + assert ":message_type_0" in count_sql + assert ":search_query" in count_sql + assert ":exclude_id_0" in count_sql + assert ":exclude_platform_0" in count_sql + assert unsafe_search not in count_sql + assert count_params["platform_0"] == "qq:%" + assert count_params["message_type_0"] == "%:group:%" + assert unsafe_search in count_params["search_query"] + assert count_params["exclude_id_0"] == "admin%" + assert count_params["exclude_platform_0"] == "slack:%" + + assert "FROM webchat_conversation WHERE" in data_sql + assert unsafe_search not in data_sql + assert ":page_size" in data_sql + assert ":offset" in data_sql + assert unsafe_search in data_params["search_query"] + assert data_params["page_size"] == 10 + assert data_params["offset"] == 10 diff --git a/tests/unit/test_t2i_security.py b/tests/unit/test_t2i_security.py new file mode 100644 index 0000000000..4c17da215b --- /dev/null +++ b/tests/unit/test_t2i_security.py @@ -0,0 +1,33 @@ +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.utils.t2i.network_strategy import NetworkRenderStrategy + + +@pytest.mark.asyncio +async def test_network_strategy_render_preserves_backticks(): + strategy = NetworkRenderStrategy() + strategy.get_template = AsyncMock(return_value="template") + strategy.render_custom_template = AsyncMock(return_value="rendered") + + result = await strategy.render("```python\nprint('hi')\n```") + + assert result == "rendered" + _, tmpl_data, return_url = strategy.render_custom_template.call_args.args + assert tmpl_data["text"] == "```python\nprint('hi')\n```" + assert return_url is False + + +def test_t2i_templates_use_json_serialization_for_text(): + template_paths = sorted( + Path("astrbot/core/utils/t2i/template").glob("*.html"), + ) + + assert template_paths + + for template_path in template_paths: + content = template_path.read_text(encoding="utf-8") + assert "text | safe" not in content + assert "text | tojson" in content diff --git a/tests/unit/test_websocket_security.py b/tests/unit/test_websocket_security.py new file mode 100644 index 0000000000..c979754ab2 --- /dev/null +++ b/tests/unit/test_websocket_security.py @@ -0,0 +1,92 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from astrbot.core.platform.sources.misskey.misskey_api import StreamingClient +from astrbot.core.platform.sources.websocket_security import ( + require_secure_transport_url, + to_websocket_url, +) + + +def test_require_secure_transport_url_allows_local_ws() -> None: + parsed = require_secure_transport_url( + "ws://localhost:5140/satori/v1/events", + label="Satori WebSocket URL", + allowed_schemes={"ws", "wss"}, + ) + + assert parsed.scheme == "ws" + + +def test_require_secure_transport_url_rejects_public_ws() -> None: + with pytest.raises( + ValueError, + match=r"must use secure transport \(https or wss\) for non-local endpoints", + ): + require_secure_transport_url( + "ws://example.com/events", + label="Satori WebSocket URL", + allowed_schemes={"ws", "wss"}, + ) + + +def test_require_secure_transport_url_rejects_bare_hostname_ws() -> None: + with pytest.raises( + ValueError, + match=r"must use secure transport \(https or wss\) for non-local endpoints", + ): + require_secure_transport_url( + "ws://prod/events", + label="Satori WebSocket URL", + allowed_schemes={"ws", "wss"}, + ) + + +def test_to_websocket_url_converts_https_to_wss() -> None: + assert to_websocket_url("https://example.com") == "wss://example.com" + assert ( + to_websocket_url("http://localhost:5140/satori/v1") + == "ws://localhost:5140/satori/v1" + ) + + +def test_to_websocket_url_rejects_unsupported_scheme() -> None: + with pytest.raises( + ValueError, + match="Misskey instance URL must use the http, https, ws, or wss scheme", + ): + to_websocket_url("ftp://example.com", label="Misskey instance URL") + + +@pytest.mark.asyncio +async def test_streaming_client_connects_with_secure_websocket_url() -> None: + client = StreamingClient("https://example.com", "token") + websocket = AsyncMock() + + with patch( + "astrbot.core.platform.sources.misskey.misskey_api.websockets.connect", + new_callable=AsyncMock, + ) as mock_connect: + mock_connect.return_value = websocket + + assert await client.connect() is True + + mock_connect.assert_awaited_once_with( + "wss://example.com/streaming?i=token", + ping_interval=30, + ping_timeout=10, + ) + + +@pytest.mark.asyncio +async def test_streaming_client_rejects_remote_http_instance() -> None: + client = StreamingClient("http://example.com", "token") + + with patch( + "astrbot.core.platform.sources.misskey.misskey_api.websockets.connect", + new_callable=AsyncMock, + ) as mock_connect: + assert await client.connect() is False + + mock_connect.assert_not_awaited()