diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index 2f01e511..cafe6529 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -20,7 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
-        timeout: float = 300,
+        timeout: float = 600,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
@@ -42,25 +42,24 @@ def __init__(
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.timeout = float(timeout)
+        self.tokenizer = self.engine.engine.tokenizer.tokenizer
 
-    @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
-        msgs = history or []
-        lines = []
-        for m in msgs:
-            if isinstance(m, dict):
-                role = m.get("role", "")
-                content = m.get("content", "")
-                lines.append(f"{role}: {content}")
-            else:
-                lines.append(str(m))
-        lines.append(prompt)
-        return "\n".join(lines)
+    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
+        messages = history or []
+        messages.append({"role": "user", "content": prompt})
+
+        return self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
 
     async def _consume_generator(self, generator):
         final_output = None
         async for request_output in generator:
-            final_output = request_output
+            if request_output.finished:
+                final_output = request_output
+                break
         return final_output
 
     async def generate_answer(
@@ -70,14 +69,14 @@ async def generate_answer(
         request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
-            temperature=self.temperature if self.temperature > 0 else 1.0,
-            top_p=self.top_p if self.temperature > 0 else 1.0,
+            temperature=self.temperature if self.temperature >= 0 else 1.0,
+            top_p=self.top_p if self.top_p >= 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 2048),
+            repetition_penalty=extra.get("repetition_penalty", 1.05),
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator), timeout=self.timeout
             )
@@ -89,7 +88,7 @@ async def generate_answer(
 
             result_text = final_output.outputs[0].text
             return result_text
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
@@ -105,14 +104,14 @@ async def generate_topk_per_token(
             logprobs=self.top_k,
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator), timeout=self.timeout
             )
+
             if (
                 not final_output
                 or not final_output.outputs
@@ -141,7 +140,7 @@ async def generate_topk_per_token(
                 return [main_token]
 
             return []
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
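
For illustration only, not part of the patch: a minimal standalone sketch of the chat-template rendering that the reworked _build_inputs now delegates to the tokenizer. It assumes a Hugging Face tokenizer that provides apply_chat_template; the model id and the conversation below are arbitrary examples.

    # Illustrative sketch; the model id is an assumption, not taken from the patch.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"},
    ]

    # Same call shape as the patched _build_inputs: append the new user turn,
    # then let the tokenizer's chat template produce the final prompt string.
    messages = list(history) + [{"role": "user", "content": "Summarize our conversation."}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string, not token ids
        add_generation_prompt=True,  # end with the assistant header so generation continues there
    )
    print(prompt)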