Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions astrbot/core/platform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,17 @@ async def load_platform(self, platform_config: dict) -> None:
)
return

platform_id = platform_config["id"]
if platform_id in self._inst_map:
logger.warning(
"平台适配器 %s(%s) 已存在,正在先终止旧实例再重新加载。",
platform_config["type"],
platform_id,
)
await self.terminate_platform(platform_id)

logger.info(
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...",
f"载入 {platform_config['type']}({platform_id}) 平台适配器 ...",
)
match platform_config["type"]:
case "aiocqhttp":
Expand Down Expand Up @@ -255,24 +264,29 @@ async def reload(self, platform_config: dict) -> None:
await self.terminate_platform(key)

async def terminate_platform(self, platform_id: str) -> None:
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
tracked_inst: Platform | None = None
info = self._inst_map.pop(platform_id, None)
if info:
tracked_inst = info["inst"]

# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id)
client_id = info["client_id"]
inst: Platform = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
),
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
insts_to_terminate: list[Platform] = []
if tracked_inst is not None:
insts_to_terminate.append(tracked_inst)

for inst in list(self.platform_insts):
if inst in insts_to_terminate:
continue
if getattr(inst, "config", {}).get("id") == platform_id:
insts_to_terminate.append(inst)

if not insts_to_terminate:
return

logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")

for inst in insts_to_terminate:
while inst in self.platform_insts:
self.platform_insts.remove(inst)
await self._terminate_inst_and_tasks(inst)

async def terminate(self) -> None:
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_platform_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio

import pytest

from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform.platform import Platform
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.register import platform_cls_map


class DummyAstrBotConfig(dict):
def save_config(self, replace_config: dict | None = None) -> None:
if replace_config is not None:
self.clear()
self.update(replace_config)


class DummyPlatform(Platform):
instances: list["DummyPlatform"] = []

def __init__(self, platform_config: dict, platform_settings: dict, event_queue):
super().__init__(platform_config, event_queue)
self.platform_settings = platform_settings
self.terminated = False
self._stop_event = asyncio.Event()
self.__class__.instances.append(self)

async def _run(self) -> None:
await self._stop_event.wait()

def run(self):
return self._run()

async def terminate(self) -> None:
self.terminated = True
self._stop_event.set()

def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="dummy",
description="dummy platform",
id=self.config["id"],
support_proactive_message=False,
)


@pytest.fixture
def manager(monkeypatch: pytest.MonkeyPatch) -> PlatformManager:
DummyPlatform.instances.clear()
monkeypatch.setitem(platform_cls_map, "dummy", DummyPlatform)
config = DummyAstrBotConfig({"platform": [], "platform_settings": {}})
return PlatformManager(config, asyncio.Queue())


@pytest.mark.asyncio
async def test_load_platform_replaces_existing_same_id(manager: PlatformManager):
config = {"id": "default", "type": "dummy", "enable": True}

await manager.load_platform(config.copy())
first_inst = DummyPlatform.instances[-1]

await manager.load_platform(config.copy())
second_inst = DummyPlatform.instances[-1]

assert first_inst is not second_inst
assert first_inst.terminated is True
assert second_inst.terminated is False
assert manager._inst_map["default"]["inst"] is second_inst
assert manager.platform_insts == [second_inst]


@pytest.mark.asyncio
async def test_terminate_platform_cleans_orphaned_instances(manager: PlatformManager):
config = {"id": "default", "type": "dummy", "enable": True}

await manager.load_platform(config.copy())
tracked_inst = DummyPlatform.instances[-1]

orphan_inst = DummyPlatform(config.copy(), {}, asyncio.Queue())
manager.platform_insts.append(orphan_inst)
manager._start_platform_task("orphan_default", orphan_inst)

await manager.terminate_platform("default")

assert tracked_inst.terminated is True
assert orphan_inst.terminated is True
assert manager.platform_insts == []
assert "default" not in manager._inst_map