From 229db7c17df2c79e42f74ae2f5a4c30cf18f3cec Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 30 Dec 2025 19:22:26 +0800 Subject: [PATCH 1/2] add model_name --- trinity/common/config.py | 1 + trinity/common/models/model.py | 14 ++++++++++++++ trinity/common/models/vllm_model.py | 4 ++++ .../common/workflows/agentscope_workflow.py | 19 ++++++++++++------- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 2a43b2e235..f33d7e5eef 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -477,6 +477,7 @@ class ModelConfig: class InferenceModelConfig: # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path model_path: Optional[str] = None + name: Optional[str] = None engine_type: str = "vllm" engine_num: int = 1 diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 3fe6f2bf37..079841801e 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -66,6 +66,10 @@ def get_model_path(self) -> Optional[str]: """Get the model path""" return None + def get_model_name(self) -> Optional[str]: + """Get the name of the model.""" + return None + def _history_recorder(func): """Decorator to record history of the model calls.""" @@ -279,6 +283,16 @@ async def model_path_async(self) -> str: """Get the model path.""" return await self.model.get_model_path.remote() + @property + def model_name(self) -> Optional[str]: + """Get the name of the model.""" + return ray.get(self.model.get_model_name.remote()) + + @property + async def model_name_async(self) -> Optional[str]: + """Get the name of the model.""" + return await self.model.get_model_name.remote() + def get_lora_request(self) -> Any: if self.enable_lora: return ray.get(self.model.get_lora_request.remote()) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 43e7c852ac..f4c810d2f2 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -39,6 +39,7 @@ def __init__( import vllm from vllm.sampling_params import RequestOutputKind + self.name = config.name self.logger = get_logger(__name__) self.vllm_version = get_vllm_version() self.config = config @@ -718,6 +719,9 @@ def get_model_version(self) -> int: def get_model_path(self) -> str: return self.config.model_path # type: ignore [return-value] + def get_model_name(self) -> Optional[str]: + return self.name # type: ignore [return-value] + def get_lora_request(self, lora_path: Optional[str] = None) -> Any: from vllm.lora.request import LoRARequest diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index a29594c1c0..c95beb56f4 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -118,13 +118,18 @@ def __init__( "top_logprobs": self.task.rollout_args.logprobs, }, ) - self.auxiliary_chat_models = [ - TrinityChatModel( - openai_async_client=aux_model, - # TODO: customize generate_kwargs for auxiliary models if needed - ) - for aux_model in (self.auxiliary_models or []) - ] + + # TODO: customize generate_kwargs for auxiliary models if needed + if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None: + self.auxiliary_chat_models = { + aux_model_wrapper.model_name + or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model) + for i, (aux_model_wrapper, aux_model) in enumerate( + zip(self.auxiliary_model_wrappers, self.auxiliary_models) + ) + } + else: + self.auxiliary_chat_models = {} def construct_experiences( self, From f1d8cf8b890c6ca522bd67a2763eabdca115c312 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 30 Dec 2025 19:32:26 +0800 Subject: [PATCH 2/2] fix pre-commit config --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91a9702b19..eda9333b2a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,6 @@ repos: rev: 23.7.0 hooks: - id: black - language_version: python3.12 args: [--line-length=100] - repo: https://github.com/pycqa/isort