From e8984ec7a6c1b042e369cdfa4e26eab5b3833704 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Mon, 13 Apr 2026 20:00:48 +0800 Subject: [PATCH] refactor(utils): Introduce lazy imports for BailianClient/GPDBClient and memory_collection module This change refactors the code to use lazy imports for `BailianClient`, `GPDBClient`, and the entire `memory_collection` module. By doing so, we defer the loading of these components until they're actually needed, which helps reduce startup time and avoids importing unnecessary dependencies like tablestore, mem0ai, numpy, fastapi, uvicorn, etc., when simply importing the `agentrun` package. The key changes include: - Moving `BailianClient` and `GPDBClient` imports inside their respective methods using local imports. - Adding conditional imports (`TYPE_CHECKING`) for `memory_collection` items. - Implementing a custom `__getattr__` function that lazily loads modules on demand rather than during initial import. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/__init__.py | 68 ++++++++++++++++------- agentrun/utils/control_api.py | 14 ++++- pyproject.toml | 6 ++ tests/unittests/utils/test_control_api.py | 12 ++-- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/agentrun/__init__.py b/agentrun/__init__.py index d87fe5f..591fb27 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -57,22 +57,26 @@ CredentialUpdateInput, RelatedResource, ) -# Memory Collection -from agentrun.memory_collection import ( - EmbedderConfig, - EmbedderConfigConfig, - LLMConfig, - LLMConfigConfig, - MemoryCollection, - MemoryCollectionClient, - MemoryCollectionCreateInput, - MemoryCollectionListInput, - MemoryCollectionListOutput, - MemoryCollectionUpdateInput, - NetworkConfiguration, - VectorStoreConfig, - VectorStoreConfigConfig, -) + +# Memory Collection - 延迟导入以避免 tablestore/mem0ai 等重型依赖 +# Lazy import to avoid heavy dependencies (tablestore, mem0ai, numpy, etc.) +# Type hints for IDE and type checkers +if TYPE_CHECKING: + from agentrun.memory_collection import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollection, + MemoryCollectionClient, + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, + ) # Model Service from agentrun.model import ( BackendType, @@ -304,6 +308,24 @@ "Config", ] +# Memory Collection 模块的所有导出(延迟加载) +# Memory Collection module exports (lazy loaded) +_MEMORY_COLLECTION_EXPORTS = { + "MemoryCollection", + "MemoryCollectionClient", + "EmbedderConfig", + "EmbedderConfigConfig", + "LLMConfig", + "LLMConfigConfig", + "NetworkConfiguration", + "VectorStoreConfig", + "VectorStoreConfigConfig", + "MemoryCollectionCreateInput", + "MemoryCollectionUpdateInput", + "MemoryCollectionListInput", + "MemoryCollectionListOutput", +} + # Server 模块的所有导出 _SERVER_EXPORTS = { "AgentRunServer", @@ -346,11 +368,19 @@ def __getattr__(name: str): - """延迟加载 server 模块的导出,避免可选依赖导致导入失败 + """延迟加载 server / memory_collection 模块的导出,避免重型依赖在 + import agentrun 时被立即加载。 - 当用户访问 server 相关的类时,才尝试导入 server 模块。 - 如果 server 可选依赖未安装,会抛出清晰的错误提示。 + Lazy-load server / memory_collection module exports to avoid pulling in + heavy dependencies (tablestore, mem0ai, fastapi, etc.) at import time. """ + # Memory Collection 模块(延迟加载以避免 tablestore/mem0ai 依赖) + if name in _MEMORY_COLLECTION_EXPORTS: + from agentrun import memory_collection + + return getattr(memory_collection, name) + + # Server 模块(延迟加载以避免 fastapi/uvicorn 依赖) if name in _SERVER_EXPORTS: try: from agentrun import server diff --git a/agentrun/utils/control_api.py b/agentrun/utils/control_api.py index b74a822..3b16fbb 100644 --- a/agentrun/utils/control_api.py +++ b/agentrun/utils/control_api.py @@ -4,16 +4,22 @@ This module defines the base class for control API. """ -from typing import Optional +from typing import Optional, TYPE_CHECKING from alibabacloud_agentrun20250910.client import Client as AgentRunClient -from alibabacloud_bailian20231229.client import Client as BailianClient from alibabacloud_devs20230714.client import Client as DevsClient -from alibabacloud_gpdb20160503.client import Client as GPDBClient from alibabacloud_tea_openapi import utils_models as open_api_util_models from agentrun.utils.config import Config +# 延迟导入:BailianClient 和 GPDBClient 仅在 knowledgebase 模块使用, +# 不在顶层导入以减少非 knowledgebase 场景的依赖加载。 +# Lazy import: BailianClient and GPDBClient are only used by the knowledgebase +# module. Deferring import to reduce dependency loading for non-KB scenarios. +if TYPE_CHECKING: + from alibabacloud_bailian20231229.client import Client as BailianClient + from alibabacloud_gpdb20160503.client import Client as GPDBClient + class ControlAPI: """控制链路客户端基类 / Control API Client Base Class @@ -88,6 +94,7 @@ def _get_bailian_client( Returns: BailianClient: 百炼 API 客户端实例 / Bailian API client instance """ + from alibabacloud_bailian20231229.client import Client as BailianClient cfg = Config.with_configs(self.config, config) endpoint = cfg.get_bailian_endpoint() @@ -116,6 +123,7 @@ def _get_gpdb_client(self, config: Optional[Config] = None) -> "GPDBClient": Returns: GPDBClient: GPDB API 客户端实例 / GPDB API client instance """ + from alibabacloud_gpdb20160503.client import Client as GPDBClient cfg = Config.with_configs(self.config, config) # GPDB 使用区域级别的 endpoint / GPDB uses region-level endpoint diff --git a/pyproject.toml b/pyproject.toml index d5d04d6..8aeb91c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,12 @@ dependencies = [ ] [project.optional-dependencies] +# CLI 最小依赖集 / Minimal dependencies for CLI binary packaging +core = [ + "click>=8.0.0", + "rich>=13.0.0", +] + server = [ "fastapi>=0.104.0", "uvicorn>=0.24.0", diff --git a/tests/unittests/utils/test_control_api.py b/tests/unittests/utils/test_control_api.py index efd799f..d69e0b1 100644 --- a/tests/unittests/utils/test_control_api.py +++ b/tests/unittests/utils/test_control_api.py @@ -282,7 +282,7 @@ def test_get_devs_client_with_read_timeout(self, mock_client_class): class TestControlAPIGetBailianClient: """测试 ControlAPI._get_bailian_client""" - @patch("agentrun.utils.control_api.BailianClient") + @patch("alibabacloud_bailian20231229.client.Client") def test_get_bailian_client_basic(self, mock_client_class): """测试获取基本百炼客户端""" config = Config( @@ -304,7 +304,7 @@ def test_get_bailian_client_basic(self, mock_client_class): assert config_arg.access_key_secret == "sk" assert config_arg.region_id == "cn-hangzhou" - @patch("agentrun.utils.control_api.BailianClient") + @patch("alibabacloud_bailian20231229.client.Client") def test_get_bailian_client_strips_https_prefix(self, mock_client_class): """测试获取百炼客户端时去除 https:// 前缀""" config = Config( @@ -323,7 +323,7 @@ def test_get_bailian_client_strips_https_prefix(self, mock_client_class): config_arg = call_args[0][0] assert config_arg.endpoint == "bailian.cn-hangzhou.aliyuncs.com" - @patch("agentrun.utils.control_api.BailianClient") + @patch("alibabacloud_bailian20231229.client.Client") def test_get_bailian_client_strips_http_prefix(self, mock_client_class): """测试获取百炼客户端时去除 http:// 前缀""" config = Config( @@ -346,7 +346,7 @@ def test_get_bailian_client_strips_http_prefix(self, mock_client_class): class TestControlAPIGetGPDBClient: """测试 ControlAPI._get_gpdb_client""" - @patch("agentrun.utils.control_api.GPDBClient") + @patch("alibabacloud_gpdb20160503.client.Client") def test_get_gpdb_client_known_region(self, mock_client_class): """测试已知 region 使用通用 endpoint""" config = Config( @@ -365,7 +365,7 @@ def test_get_gpdb_client_known_region(self, mock_client_class): config_arg = call_args[0][0] assert config_arg.endpoint == "gpdb.aliyuncs.com" - @patch("agentrun.utils.control_api.GPDBClient") + @patch("alibabacloud_gpdb20160503.client.Client") def test_get_gpdb_client_unknown_region(self, mock_client_class): """测试未知 region 使用区域级别 endpoint""" config = Config( @@ -384,7 +384,7 @@ def test_get_gpdb_client_unknown_region(self, mock_client_class): config_arg = call_args[0][0] assert config_arg.endpoint == "gpdb.us-west-1.aliyuncs.com" - @patch("agentrun.utils.control_api.GPDBClient") + @patch("alibabacloud_gpdb20160503.client.Client") def test_get_gpdb_client_all_known_regions(self, mock_client_class): """测试所有已知 region 使用通用 endpoint""" known_regions = [