diff --git a/graphgen/bases/base_llm_wrapper.py b/graphgen/bases/base_llm_wrapper.py
index 8b6dbec7..2755da76 100644
--- a/graphgen/bases/base_llm_wrapper.py
+++ b/graphgen/bases/base_llm_wrapper.py
@@ -26,11 +26,11 @@ def __init__(
         **kwargs: Any,
     ):
         self.system_prompt = system_prompt
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.repetition_penalty = repetition_penalty
-        self.top_p = top_p
-        self.top_k = top_k
+        self.temperature = float(temperature)
+        self.max_tokens = int(max_tokens)
+        self.repetition_penalty = float(repetition_penalty)
+        self.top_p = float(top_p)
+        self.top_k = int(top_k)
         self.tokenizer = tokenizer

         for k, v in kwargs.items():
diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index 6ae5bf3b..2f01e511 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio

 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,12 +20,9 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
+        timeout: float = 300,
         **kwargs: Any,
     ):
-        temperature = float(temperature)
-        top_p = float(top_p)
-        top_k = int(top_k)
-
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
@@ -43,6 +41,7 @@ def __init__(
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.timeout = float(timeout)

     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -58,6 +57,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
             lines.append(prompt)
         return "\n".join(lines)

+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -72,14 +77,21 @@ async def generate_answer(

         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""

-        if not final_output or not final_output.outputs:
-            return ""
+            result_text = final_output.outputs[0].text
+            return result_text

-        return final_output.outputs[0].text
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise

     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -91,42 +103,47 @@ async def generate_topk_per_token(
             temperature=0,
             max_tokens=1,
             logprobs=self.top_k,
-            prompt_logprobs=1,
         )

         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
             )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []

-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
-            )
-            return [main_token]
-        return []
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
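
Usage note (not part of the patch): a minimal sketch of how the new `timeout` knob and the abort-on-failure behavior would look from the caller's side. The class name `VllmWrapper` and the `model` constructor argument are assumptions for illustration; only `timeout` and `generate_answer` come from the diff above.

import asyncio

from graphgen.models.llm.local.vllm_wrapper import VllmWrapper  # class name assumed


async def main() -> None:
    # `model` is an illustrative constructor argument; `timeout` is the new knob (seconds).
    llm = VllmWrapper(model="Qwen/Qwen2.5-7B-Instruct", timeout=120)
    try:
        print(await llm.generate_answer("Explain retrieval-augmented generation."))
    except asyncio.TimeoutError:
        # asyncio.wait_for expired; the wrapper aborted the vLLM request before
        # re-raising, so no orphaned generation is left running on the engine.
        print("generation timed out")


if __name__ == "__main__":
    asyncio.run(main())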