From 59af943adcdff3305958479730d98ed6c547c62a Mon Sep 17 00:00:00 2001
From: chenzihong_gavin <522023320011@smail.nju.edu.cn>
Date: Sat, 27 Dec 2025 12:21:57 +0800
Subject: [PATCH 1/3] fix: fix timeout error in VLLMWrapper

---
 graphgen/models/llm/local/vllm_wrapper.py | 111 ++++++++++++++--------
 1 file changed, 73 insertions(+), 38 deletions(-)

diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index fc412b51..74dc3c4e 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,6 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         topk: int = 5,
+        timeout: float = 300.0,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
@@ -42,6 +44,7 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
+        self.timeout = timeout
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -57,6 +60,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)
 
+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -71,14 +80,27 @@ async def generate_answer(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if not final_output or not final_output.outputs:
-            return ""
-
-        return final_output.outputs[0].text
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+            
+            if not final_output or not final_output.outputs:
+                return ""
+
+            result_text = final_output.outputs[0].text
+            return result_text
+        
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception as e:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -95,37 +117,49 @@ async def generate_topk_per_token(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
             )
-            return [main_token]
-        return []
+            
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+        
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception as e:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -133,3 +167,4 @@ async def generate_inputs_prob(
         raise NotImplementedError(
            "VLLMWrapper does not support per-token logprobs yet."
         )
+

From 95050df359fc89119ae4edb2b5ecd79fe78fa4bf Mon Sep 17 00:00:00 2001
From: chenzihong <522023320011@smail.nju.edu.cn>
Date: Thu, 8 Jan 2026 23:52:58 +0800
Subject: [PATCH 2/3] fix: remove unnecessary prompt_logprobs=1

---
 graphgen/bases/base_llm_wrapper.py        | 10 +++----
 graphgen/models/llm/local/vllm_wrapper.py | 35 ++++++-----------------
 2 files changed, 13 insertions(+), 32 deletions(-)

diff --git a/graphgen/bases/base_llm_wrapper.py b/graphgen/bases/base_llm_wrapper.py
index 8b6dbec7..2755da76 100644
--- a/graphgen/bases/base_llm_wrapper.py
+++ b/graphgen/bases/base_llm_wrapper.py
@@ -26,11 +26,11 @@ def __init__(
         **kwargs: Any,
     ):
         self.system_prompt = system_prompt
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.repetition_penalty = repetition_penalty
-        self.top_p = top_p
-        self.top_k = top_k
+        self.temperature = float(temperature)
+        self.max_tokens = int(max_tokens)
+        self.repetition_penalty = float(repetition_penalty)
+        self.top_p = float(top_p)
+        self.top_k = int(top_k)
         self.tokenizer = tokenizer
 
         for k, v in kwargs.items():
diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index ad9491a6..534d62f6 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -20,13 +20,9 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
-        timeout: float = 300
+        timeout: float = 300,
         **kwargs: Any,
     ):
-        temperature = float(temperature)
-        top_p = float(top_p)
-        top_k = int(top_k)
-
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
@@ -45,10 +41,7 @@ def __init__(
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.timeout = timeout
+        self.timeout = float(timeout)
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -89,20 +82,15 @@ async def generate_answer(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
             )
-            
+
             if not final_output or not final_output.outputs:
                 return ""
 
             result_text = final_output.outputs[0].text
             return result_text
-        
-        except asyncio.TimeoutError:
-            await self.engine.abort(request_id)
-            raise
-        except asyncio.CancelledError:
-            await self.engine.abort(request_id)
-            raise
+
         except Exception as e:
+            print(f"Error in generate_answer: {e}")
             await self.engine.abort(request_id)
             raise
 
@@ -116,7 +104,6 @@ async def generate_topk_per_token(
             temperature=0,
             max_tokens=1,
             logprobs=self.top_k,
-            prompt_logprobs=1,
         )
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
@@ -126,7 +113,7 @@ async def generate_topk_per_token(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
             )
-            
+
             if (
                 not final_output
                 or not final_output.outputs
@@ -154,14 +141,9 @@ async def generate_topk_per_token(
                 )
                 return [main_token]
             return []
-        
-        except asyncio.TimeoutError:
-            await self.engine.abort(request_id)
-            raise
-        except asyncio.CancelledError:
-            await self.engine.abort(request_id)
-            raise
+
         except Exception as e:
+            print(f"Error in generate_topk_per_token: {e}")
             await self.engine.abort(request_id)
             raise
 
@@ -171,4 +153,3 @@ async def generate_inputs_prob(
         raise NotImplementedError(
             "VLLMWrapper does not support per-token logprobs yet."
         )
-

From bf14c3833b3a28cdb67d4530f5507da28aa9418e Mon Sep 17 00:00:00 2001
From: chenzihong <522023320011@smail.nju.edu.cn>
Date: Thu, 8 Jan 2026 23:55:37 +0800
Subject: [PATCH 3/3] fix: catch CancelledError

---
 graphgen/models/llm/local/vllm_wrapper.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index 534d62f6..2f01e511 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -89,8 +89,7 @@ async def generate_answer(
             result_text = final_output.outputs[0].text
             return result_text
 
-        except Exception as e:
-            print(f"Error in generate_answer: {e}")
+        except (Exception, asyncio.CancelledError):
             await self.engine.abort(request_id)
             raise
 
@@ -142,8 +141,7 @@ async def generate_topk_per_token(
                 return [main_token]
             return []
 
-        except Exception as e:
-            print(f"Error in generate_topk_per_token: {e}")
+        except (Exception, asyncio.CancelledError):
             await self.engine.abort(request_id)
             raise
 
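Note on the pattern the series converges on: consume vLLM's async generator under asyncio.wait_for, and on any failure (timeout, cancellation, or other error) call engine.abort(request_id) before re-raising, so the engine can release the in-flight request. Below is a minimal, self-contained sketch of that flow using a stand-in engine; FakeEngine, consume, generate_with_timeout, and the 2-second timeout are illustrative assumptions, not the real vLLM or GraphGen API.

import asyncio
import uuid


class FakeEngine:
    # Stand-in for vLLM's AsyncLLMEngine: streams partial outputs slowly and supports abort().
    async def generate(self, prompt: str, request_id: str):
        # Simulate a streaming request that takes far longer than the caller's timeout.
        for i in range(10):
            await asyncio.sleep(1.0)
            yield f"{prompt} ... partial output {i}"

    async def abort(self, request_id: str) -> None:
        print(f"aborted request {request_id}")


async def consume(generator):
    # Keep only the last yielded item, mirroring _consume_generator in patch 1.
    final = None
    async for item in generator:
        final = item
    return final


async def generate_with_timeout(engine, prompt: str, timeout: float) -> str:
    request_id = str(uuid.uuid4())
    generator = engine.generate(prompt, request_id)
    try:
        final = await asyncio.wait_for(consume(generator), timeout=timeout)
        return final or ""
    except (Exception, asyncio.CancelledError):
        # Abort the in-flight request before propagating, as patch 3 does.
        await engine.abort(request_id)
        raise


if __name__ == "__main__":
    try:
        asyncio.run(generate_with_timeout(FakeEngine(), "hello", timeout=2.0))
    except asyncio.TimeoutError:
        print("request timed out and was aborted")

The wait_for timeout cancels the consuming task and surfaces TimeoutError to the caller, while the explicit abort keeps the request from lingering inside the engine. CancelledError has to be named separately because, on Python 3.8+, it derives from BaseException and would slip past a bare "except Exception" — which is exactly what the third patch addresses.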