diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py
index cf7d2e0797..87199cde71 100644
--- a/astrbot/core/computer/booters/local.py
+++ b/astrbot/core/computer/booters/local.py
@@ -141,11 +141,14 @@ async def exec(
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
try:
+ # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
+ # Executes the current interpreter with a fixed argv list and shell=False.
result = subprocess.run(
- [os.environ.get("PYTHON", sys.executable), "-c", code],
+ [sys.executable, "-c", code],
timeout=timeout,
capture_output=True,
text=True,
+ shell=False,
)
stdout = "" if silent else result.stdout
stderr = result.stderr if result.returncode != 0 else ""
diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py
index b326ebb449..bbf65da7fa 100644
--- a/astrbot/core/db/migration/sqlite_v3.py
+++ b/astrbot/core/db/migration/sqlite_v3.py
@@ -164,7 +164,7 @@ def insert_llm_metrics(self, metrics: dict) -> None:
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
- where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
+ min_timestamp = int(time.time()) - offset_sec
try:
c = self.conn.cursor()
@@ -174,8 +174,9 @@ def get_base_stats(self, offset_sec: int = 86400) -> Stats:
c.execute(
"""
SELECT * FROM platform
- """
- + where_clause,
+ WHERE timestamp >= :min_timestamp
+ """,
+ {"min_timestamp": min_timestamp},
)
platform = []
@@ -203,7 +204,7 @@ def get_total_message_count(self) -> int:
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
- where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
+ min_timestamp = int(time.time()) - offset_sec
try:
c = self.conn.cursor()
@@ -213,9 +214,10 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
- """
- + where_clause
- + " GROUP BY name",
+ WHERE timestamp >= :min_timestamp
+ GROUP BY name
+ """,
+ {"min_timestamp": min_timestamp},
)
platform = []
@@ -403,14 +405,15 @@ def get_filtered_conversations(
try:
# 构建查询条件
where_clauses = []
- params = []
+ params: dict[str, Any] = {}
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
- for platform in platforms:
- platform_conditions.append("user_id LIKE ?")
- params.append(f"{platform}:%")
+ for index, platform in enumerate(platforms):
+ param_name = f"platform_{index}"
+ platform_conditions.append(f"user_id LIKE :{param_name}")
+ params[param_name] = f"{platform}:%"
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
@@ -418,9 +421,10 @@ def get_filtered_conversations(
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
- for msg_type in message_types:
- message_type_conditions.append("user_id LIKE ?")
- params.append(f"%:{msg_type}:%")
+ for index, msg_type in enumerate(message_types):
+ param_name = f"message_type_{index}"
+ message_type_conditions.append(f"user_id LIKE :{param_name}")
+ params[param_name] = f"%:{msg_type}:%"
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
@@ -429,28 +433,32 @@ def get_filtered_conversations(
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
- "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
+ "("
+ "title LIKE :search_query OR user_id LIKE :search_query OR "
+ "cid LIKE :search_query OR history LIKE :search_query"
+ ")",
)
- search_param = f"%{search_query}%"
- params.extend([search_param, search_param, search_param, search_param])
+ params["search_query"] = f"%{search_query}%"
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
- for exclude_id in exclude_ids:
- where_clauses.append("user_id NOT LIKE ?")
- params.append(f"{exclude_id}%")
+ for index, exclude_id in enumerate(exclude_ids):
+ param_name = f"exclude_id_{index}"
+ where_clauses.append(f"user_id NOT LIKE :{param_name}")
+ params[param_name] = f"{exclude_id}%"
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
- for exclude_platform in exclude_platforms:
- where_clauses.append("user_id NOT LIKE ?")
- params.append(f"{exclude_platform}:%")
+ for index, exclude_platform in enumerate(exclude_platforms):
+ param_name = f"exclude_platform_{index}"
+ where_clauses.append(f"user_id NOT LIKE :{param_name}")
+ params[param_name] = f"{exclude_platform}:%"
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
- count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
+ count_sql = "SELECT COUNT(*) FROM webchat_conversation" + where_sql
# 获取总记录数
c.execute(count_sql, params)
@@ -460,14 +468,14 @@ def get_filtered_conversations(
offset = (page - 1) * page_size
# 构建分页数据查询
- data_sql = f"""
- SELECT user_id, cid, created_at, updated_at, title, persona_id
- FROM webchat_conversation
- {where_sql}
- ORDER BY updated_at DESC
- LIMIT ? OFFSET ?
- """
- query_params = params + [page_size, offset]
+ data_sql = (
+ "SELECT user_id, cid, created_at, updated_at, title, persona_id\n"
+ "FROM webchat_conversation"
+ f"{where_sql}\n"
+ "ORDER BY updated_at DESC\n"
+ "LIMIT :page_size OFFSET :offset"
+ )
+ query_params = {**params, "page_size": page_size, "offset": offset}
# 获取分页数据
c.execute(data_sql, query_params)
diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py
index 3e5eb9a90e..d43eae16a5 100644
--- a/astrbot/core/platform/sources/misskey/misskey_api.py
+++ b/astrbot/core/platform/sources/misskey/misskey_api.py
@@ -15,6 +15,7 @@
from astrbot.api import logger
+from ..websocket_security import to_websocket_url
from .misskey_utils import FileIDExtractor
# Constants
@@ -56,10 +57,7 @@ def __init__(self, instance_url: str, access_token: str) -> None:
async def connect(self) -> bool:
try:
- ws_url = self.instance_url.replace("https://", "wss://").replace(
- "http://",
- "ws://",
- )
+ ws_url = to_websocket_url(self.instance_url, label="Misskey instance URL")
ws_url += f"/streaming?i={self.access_token}"
self.websocket = await websockets.connect(
diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py
index 5c2f7a37f3..6458a53658 100644
--- a/astrbot/core/platform/sources/satori/satori_adapter.py
+++ b/astrbot/core/platform/sources/satori/satori_adapter.py
@@ -27,6 +27,8 @@
)
from astrbot.core.platform.astr_message_event import MessageSession
+from ..websocket_security import require_secure_transport_url
+
@register_platform_adapter(
"satori", "Satori 协议适配器", support_streaming_message=False
@@ -137,9 +139,11 @@ async def connect_websocket(self) -> None:
logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")
- if not self.endpoint.startswith(("ws://", "wss://")):
- logger.error(f"无效的WebSocket URL: {self.endpoint}")
- raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
+ require_secure_transport_url(
+ self.endpoint,
+ label="Satori WebSocket URL",
+ allowed_schemes={"ws", "wss"},
+ )
try:
websocket = await connect(
diff --git a/astrbot/core/platform/sources/websocket_security.py b/astrbot/core/platform/sources/websocket_security.py
new file mode 100644
index 0000000000..47d5904830
--- /dev/null
+++ b/astrbot/core/platform/sources/websocket_security.py
@@ -0,0 +1,77 @@
+import ipaddress
+from urllib.parse import SplitResult, urlsplit, urlunsplit
+
+_ALLOWED_INSECURE_SUFFIXES = (".local", ".internal")
+
+
+def _is_local_or_private_host(hostname: str | None) -> bool:
+ if not hostname:
+ return False
+
+ normalized = hostname.strip("[]").lower()
+ if normalized == "localhost":
+ return True
+ if normalized.endswith(_ALLOWED_INSECURE_SUFFIXES):
+ return True
+
+ try:
+ address = ipaddress.ip_address(normalized)
+ except ValueError:
+ return False
+
+ return address.is_loopback or address.is_private or address.is_link_local
+
+
+def require_secure_transport_url(
+ url: str,
+ *,
+ label: str,
+ allowed_schemes: set[str],
+) -> SplitResult:
+ parsed = urlsplit(url)
+ if parsed.scheme not in allowed_schemes:
+ allowed = ", ".join(sorted(allowed_schemes))
+ raise ValueError(f"{label} must use one of: {allowed}")
+
+ if parsed.scheme in {"http", "ws"} and not _is_local_or_private_host(
+ parsed.hostname
+ ):
+ raise ValueError(
+ f"{label} must use secure transport (https or wss) for non-local endpoints: {url}",
+ )
+
+ return parsed
+
+
+def to_websocket_url(url: str, *, label: str = "WebSocket URL") -> str:
+ normalized_url = url.rstrip("/")
+ parsed = urlsplit(normalized_url)
+ allowed_schemes = {"http", "https", "ws", "wss"}
+
+ if parsed.scheme not in allowed_schemes:
+ raise ValueError(
+ f"{label} must use the http, https, ws, or wss scheme: {normalized_url}",
+ )
+
+ parsed = require_secure_transport_url(
+ normalized_url,
+ label=label,
+ allowed_schemes=allowed_schemes,
+ )
+ scheme_map = {
+ "http": "ws",
+ "https": "wss",
+ "ws": "ws",
+ "wss": "wss",
+ }
+
+ try:
+ ws_scheme = scheme_map[parsed.scheme]
+ except KeyError as exc:
+ raise ValueError(
+ f"{label} must use the http, https, ws, or wss scheme: {normalized_url}",
+ ) from exc
+
+ return urlunsplit(
+ parsed._replace(scheme=ws_scheme),
+ )
diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py
index 53d9441fab..c6c54db17c 100644
--- a/astrbot/core/utils/t2i/network_strategy.py
+++ b/astrbot/core/utils/t2i/network_strategy.py
@@ -129,7 +129,6 @@ async def render(
if not template_name:
template_name = "base"
tmpl_str = await self.get_template(name=template_name)
- text = text.replace("`", "\\`")
return await self.render_custom_template(
tmpl_str,
{"text": text, "version": f"v{VERSION}"},
diff --git a/astrbot/core/utils/t2i/template/astrbot_powershell.html b/astrbot/core/utils/t2i/template/astrbot_powershell.html
index 9ed3e77a55..93237023b6 100644
--- a/astrbot/core/utils/t2i/template/astrbot_powershell.html
+++ b/astrbot/core/utils/t2i/template/astrbot_powershell.html
@@ -177,8 +177,8 @@