fix: fix timeout error in vllmwrapper #143
Changes from all commits: 59af943, 9e7f8a2, 95050df, bf14c38
```diff
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio

 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
```
```diff
@@ -19,12 +20,9 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
+        timeout: float = 300,
         **kwargs: Any,
     ):
-        temperature = float(temperature)
-        top_p = float(top_p)
-        top_k = int(top_k)
-
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
```
```diff
@@ -43,6 +41,7 @@ def __init__(
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.timeout = float(timeout)

     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
```
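Taken together, the two hunks above thread a `timeout` argument (defaulting to 300 seconds) through the constructor and store it on the instance for the `generate_*` methods to use. A hypothetical construction call, assuming the wrapper class is named `VLLMWrapper` (neither the class name nor its other required arguments are visible in these hunks):

```python
# Hypothetical usage; class name and model id are placeholders, not taken from the PR.
wrapper = VLLMWrapper(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model id
    temperature=0.6,
    top_k=5,
    timeout=120.0,  # generate_* calls that exceed two minutes will now fail fast
)
```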
```diff
@@ -58,6 +57,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)

+    async def _consume_generator(self, generator):
```
Contributor

For better code clarity and maintainability, it's good practice to add type hints to method signatures. Since the specific vLLM types are not imported at the module level, using `typing.Any` is a reasonable approach here.

Suggested change:
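The body of the suggested change is not preserved here; a plausible reconstruction, assuming the reviewer proposed `Any` annotations on the new helper:

```python
# Reconstructed suggestion (assumption): annotate with Any, since vLLM's
# concrete RequestOutput type is only imported lazily inside __init__.
async def _consume_generator(self, generator: Any) -> Any:
    final_output = None
    async for request_output in generator:
        final_output = request_output
    return final_output
```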
```diff
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
```
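`asyncio.wait_for` takes an awaitable, while `AsyncLLMEngine.generate` returns an async generator, so the stream has to be drained inside a coroutine before a timeout can be applied to it; that is what the new `_consume_generator` helper provides. A self-contained sketch of the same pattern, with a slow dummy generator standing in for vLLM:

```python
import asyncio


async def slow_stream():
    # Stand-in for engine.generate(): yields partial outputs over time.
    for i in range(5):
        await asyncio.sleep(1)
        yield f"partial-{i}"


async def consume(gen):
    # Drain the generator and keep only the last item, mirroring _consume_generator.
    final = None
    async for item in gen:
        final = item
    return final


async def main():
    try:
        # wait_for needs a coroutine, so the generator is wrapped in consume().
        print(await asyncio.wait_for(consume(slow_stream()), timeout=2.0))
    except asyncio.TimeoutError:
        print("generation timed out")


asyncio.run(main())
```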
```diff
@@ -72,14 +77,21 @@ async def generate_answer(
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )

-        if not final_output or not final_output.outputs:
-            return ""
+            if not final_output or not final_output.outputs:
+                return ""

-        result_text = final_output.outputs[0].text
-        return result_text
+            return final_output.outputs[0].text
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise
```
Comment on lines +80 to +94

Contributor

The timeout handling logic in this `try`/`except` block is duplicated in `generate_topk_per_token` below; consider extracting it into a shared helper so both call sites stay in sync. For example, you could create a helper like this:

```python
async def _run_generation(self, result_generator, request_id):
    try:
        return await asyncio.wait_for(
            self._consume_generator(result_generator),
            timeout=self.timeout
        )
    except (Exception, asyncio.CancelledError):
        await self.engine.abort(request_id)
        raise
```
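For illustration, `generate_answer` could then collapse to something like the sketch below (hypothetical code, not part of this PR; the `SamplingParams` and `request_id` construction are not shown in the diff, so they are stubbed out here):

```python
# Hypothetical sketch of generate_answer built on the suggested _run_generation helper.
async def generate_answer(
    self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
    full_prompt = self._build_inputs(text, history)
    sp = ...  # SamplingParams built as in the current method (elided in the diff)
    request_id = ...  # built as in the current method (elided in the diff)

    result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
    final_output = await self._run_generation(result_generator, request_id)

    if not final_output or not final_output.outputs:
        return ""
    return final_output.outputs[0].text
```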
```diff

     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
```
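One subtlety in the pattern above: when the timeout expires, `asyncio.wait_for` cancels the consumer task and raises `asyncio.TimeoutError`, which (being an `Exception`) lands in the `except (Exception, asyncio.CancelledError)` branch, so the request is aborted engine-side before the error propagates and vLLM does not keep generating for an abandoned caller. The same pattern is applied again in `generate_topk_per_token` below. A runnable sketch with a dummy engine (all names here are illustrative, not vLLM's API):

```python
import asyncio


class DummyEngine:
    # Stand-in for AsyncLLMEngine; abort() just records the call.
    def __init__(self):
        self.aborted = []

    async def abort(self, request_id):
        self.aborted.append(request_id)


async def never_finishes():
    await asyncio.sleep(3600)  # simulates a request that hangs


async def generate(engine, request_id, timeout=0.1):
    try:
        return await asyncio.wait_for(never_finishes(), timeout=timeout)
    except (Exception, asyncio.CancelledError):
        # Free engine-side resources before re-raising, as in the PR.
        await engine.abort(request_id)
        raise


async def main():
    engine = DummyEngine()
    try:
        await generate(engine, "req-1")
    except asyncio.TimeoutError:
        print("aborted requests:", engine.aborted)  # ['req-1']


asyncio.run(main())
```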
```diff
@@ -91,42 +103,47 @@ async def generate_topk_per_token(
             temperature=0,
             max_tokens=1,
             logprobs=self.top_k,
             prompt_logprobs=1,
         )

         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
-            )
-            return [main_token]
-        return []
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
```
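For context on the (unchanged, but re-indented) token-ranking logic now inside the `try` block: vLLM reports log-probabilities, and the code converts them to linear probabilities with `math.exp` before sorting candidates in descending order. A tiny standalone illustration with made-up values:

```python
import math

# Made-up logprobs for three candidate tokens at a single decoding step.
top_logprobs = {"yes": -0.105, "no": -2.303, "maybe": -4.605}

candidates = sorted(
    ((tok, math.exp(lp)) for tok, lp in top_logprobs.items()),
    key=lambda pair: -pair[1],
)
print(candidates)
# [('yes', 0.900...), ('no', 0.0999...), ('maybe', 0.00999...)]
```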