1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,6 @@ repos:
    rev: 23.7.0
    hooks:
      - id: black
-        language_version: python3.12
        args: [--line-length=100]

  - repo: https://github.com/pycqa/isort
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -477,6 +477,7 @@ class ModelConfig:
class InferenceModelConfig:
    # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
    model_path: Optional[str] = None
+    name: Optional[str] = None

    engine_type: str = "vllm"
    engine_num: int = 1
14 changes: 14 additions & 0 deletions trinity/common/models/model.py
@@ -66,6 +66,10 @@ def get_model_path(self) -> Optional[str]:
        """Get the model path"""
        return None

+    def get_model_name(self) -> Optional[str]:
+        """Get the name of the model."""
+        return None
+

def _history_recorder(func):
    """Decorator to record history of the model calls."""
@@ -279,6 +283,16 @@ async def model_path_async(self) -> str:
        """Get the model path."""
        return await self.model.get_model_path.remote()

+    @property
+    def model_name(self) -> Optional[str]:
+        """Get the name of the model."""
+        return ray.get(self.model.get_model_name.remote())
+
+    @property
+    async def model_name_async(self) -> Optional[str]:
+        """Get the name of the model."""
+        return await self.model.get_model_name.remote()
+
    def get_lora_request(self) -> Any:
        if self.enable_lora:
            return ray.get(self.model.get_lora_request.remote())
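
For context, a brief usage sketch of the sync/async pair added above. Here `wrapper` stands in for the model wrapper object that exposes these properties; the sketch is illustrative only and not code from this PR:

from typing import Optional


def name_from_sync_code(wrapper) -> Optional[str]:
    # Blocking access: the property resolves the remote call via ray.get(),
    # which is fine outside a running event loop.
    return wrapper.model_name


async def name_from_async_code(wrapper) -> Optional[str]:
    # Non-blocking access: the async property awaits the remote call instead
    # of blocking, so other coroutines keep running.
    return await wrapper.model_name_async
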
4 changes: 4 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -39,6 +39,7 @@ def __init__(
        import vllm
        from vllm.sampling_params import RequestOutputKind

+        self.name = config.name
        self.logger = get_logger(__name__)
        self.vllm_version = get_vllm_version()
        self.config = config
@@ -718,6 +719,9 @@ def get_model_version(self) -> int:
    def get_model_path(self) -> str:
        return self.config.model_path  # type: ignore [return-value]

+    def get_model_name(self) -> Optional[str]:
+        return self.name  # type: ignore [return-value]
+
    def get_lora_request(self, lora_path: Optional[str] = None) -> Any:
        from vllm.lora.request import LoRARequest

19 changes: 12 additions & 7 deletions trinity/common/workflows/agentscope_workflow.py
@@ -118,13 +118,18 @@ def __init__(
                "top_logprobs": self.task.rollout_args.logprobs,
            },
        )
-        self.auxiliary_chat_models = [
-            TrinityChatModel(
-                openai_async_client=aux_model,
-                # TODO: customize generate_kwargs for auxiliary models if needed
-            )
-            for aux_model in (self.auxiliary_models or [])
-        ]
+
+        # TODO: customize generate_kwargs for auxiliary models if needed
+        if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None:
+            self.auxiliary_chat_models = {
+                aux_model_wrapper.model_name
+                or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model)
+                for i, (aux_model_wrapper, aux_model) in enumerate(
+                    zip(self.auxiliary_model_wrappers, self.auxiliary_models)
+                )
+            }
Comment on lines +123 to +130
Severity: high

The use of aux_model_wrapper.model_name within the __init__ method introduces a blocking ray.get() call. Since this workflow is instantiated within an async context, this will block the event loop, potentially causing significant performance degradation or even deadlocks.

I recommend deferring the initialization of self.auxiliary_chat_models to an async method, such as run_async, where you can use the non-blocking model_name_async property.

Here is a suggested refactoring:

  1. In __init__, initialize self.auxiliary_chat_models to None:

    # In __init__
    self.auxiliary_chat_models: Optional[Dict[str, "TrinityChatModel"]] = None
  2. At the beginning of run_async, populate the dictionary asynchronously:

    # At the beginning of run_async
    if self.auxiliary_chat_models is None:
        if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None:
            model_names = await asyncio.gather(
                *[w.model_name_async for w in self.auxiliary_model_wrappers]
            )
            self.auxiliary_chat_models = {
                name or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model)
                for i, (name, aux_model) in enumerate(zip(model_names, self.auxiliary_models))
            }
        else:
            self.auxiliary_chat_models = {}

This change will ensure non-blocking behavior and proper use of async capabilities.

+        else:
+            self.auxiliary_chat_models = {}

    def construct_experiences(
        self,
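
To make the reviewer's blocking-call concern concrete, here is a minimal, self-contained sketch (the toy actor and function names are assumptions, not Trinity code) contrasting a blocking ray.get() inside a coroutine with awaiting the remote call:

import asyncio

import ray


@ray.remote
class ToyModel:
    def get_model_name(self) -> str:
        return "toy-model"


async def blocking_style(model) -> str:
    # ray.get() blocks the calling thread until the result is ready, so every
    # other coroutine scheduled on this event loop is stalled in the meantime.
    return ray.get(model.get_model_name.remote())


async def non_blocking_style(model) -> str:
    # Awaiting the ObjectRef suspends only this coroutine; the event loop is
    # free to run other tasks while the remote call is in flight.
    return await model.get_model_name.remote()


async def main() -> None:
    ray.init()
    model = ToyModel.remote()
    print(await non_blocking_style(model))
    ray.shutdown()


if __name__ == "__main__":
    asyncio.run(main())

The workflow's __init__ runs inside such a loop, which is why the first pattern is flagged; the suggested refactoring moves the name lookups into run_async, where the second pattern applies.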