diff --git a/pyproject.toml b/pyproject.toml
index b7e3227a0a..f741f8151c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
     "sortedcontainers",
     "word2number",
     "transformers",
+    "tinker",
 ]
 
 [project.scripts]
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 61fd03e675..7b545d52f3 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1219,3 +1219,91 @@ async def test_generate(self):
             response.prompt_length, 40960
         )  # If not long enough, please add more files to prompt
         self.assertGreater(response.logprobs.shape[0], 1000)
+
+
+class TestTinkerAPI(RayUnittestBaseAysnc):
+    """Test the Tinker API integration with the vLLM engine."""
+
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_tinker_api(self):
+        from tinker import types
+        from transformers import AutoTokenizer
+
+        engine = self.engines[0]
+        tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        result_dict = tokenizer.apply_chat_template(
+            messages,
+            chat_template=CHAT_TEMPLATE,
+            add_generation_prompt=False,
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            add_special_tokens=False,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+        )
+        prompt = types.ModelInput.from_ints(
+            result_dict["input_ids"][0].tolist(),
+        )
+        # sample api without prompt logprobs
+        num_samples = 4
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7),  # no limit on length
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "stop")
+        self.assertIsNone(response.prompt_logprobs)
+        self.assertIsNone(response.topk_prompt_logprobs)
+        # sample api with prompt logprobs
+        num_samples = 2
+        topk_prompt_logprobs = 3
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
+            include_prompt_logprobs=True,
+            topk_prompt_logprobs=topk_prompt_logprobs,
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "length")
+        self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.prompt_logprobs[0])
+        self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.topk_prompt_logprobs[0])
+        for topk_logprobs in response.topk_prompt_logprobs[1:]:
+            self.assertIsNotNone(topk_logprobs)
+            self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
+        # compute_logprob api
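+        # Scoring a prompt is done by sampling a single token (hence the "length" stop
+        # reason) with include_prompt_logprobs=True and reading back the prompt logprobs.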
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 60131f13aa..43e7c852ac 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -3,7 +3,7 @@
 import asyncio
 import os
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 import ray
@@ -402,6 +402,92 @@ async def logprobs(  # type: ignore [override]
             dtype=torch.float32,
         )
 
+    async def sample(
+        self,
+        prompt: Any,
+        num_samples: int,
+        sampling_params: Any,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_request: Optional[Any] = None,
+    ) -> Any:
+        """Tinker-compatible sampling interface.
+
+        Args:
+            prompt (ModelInput): The input prompt.
+            num_samples (int): The number of samples to generate.
+            sampling_params (SamplingParams): The sampling parameters.
+            include_prompt_logprobs (bool): Whether to include prompt logprobs.
+            topk_prompt_logprobs (int): The number of top-k prompt logprobs to include per prompt token.
+            lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
+        Returns:
+            SampleResponse: The sample response.
+        """
+        from tinker.types import SampledSequence, SampleResponse
+
+        params = {
+            "max_tokens": sampling_params.max_tokens
+            if sampling_params.max_tokens is not None
+            else self.config.max_response_tokens,
+            "seed": sampling_params.seed if sampling_params.seed is not None else self.config.seed,
+            "top_k": sampling_params.top_k,
+            "top_p": sampling_params.top_p,
+            "temperature": sampling_params.temperature,
+            "n": num_samples,
+            "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
+            # in vLLM, 0 means only return the chosen token's logprob
+            "logprobs": 0,
+        }
+        if sampling_params.stop is not None:
+            params["stop"] = sampling_params.stop
+        req_output = await self._generate_internal(
+            prompt={"prompt_token_ids": prompt.to_ints()},
+            lora_request=lora_request,
+            **params,
+        )
+        sequences = []
+        # vLLM's prompt_logprobs output does not include a value for the first token.
+        # Initialize with [None] to align with the prompt tokens.
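+        # prompt_logprobs[i] will hold the logprob of prompt token i, and (when requested)
+        # topk_prompt_logprobs_list[i] the top-k (token_id, logprob) pairs for that position,
+        # so both lists stay index-aligned with the prompt tokens.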
+        topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
+        prompt_logprobs: List[Optional[float]] = [None]
+
+        # collect prompt logprobs
+        if include_prompt_logprobs:
+            for logprob_dict in req_output.prompt_logprobs[1:]:
+                prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
+                if topk_prompt_logprobs > 0:
+                    # collect top-k prompt logprobs
+                    # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
+                    logprob_items = list(logprob_dict.items())
+                    # sort by Logprob.rank
+                    logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
+                    # pick topk
+                    topk = logprob_items_sorted[:topk_prompt_logprobs]
+                    # record as (token_id, logprob)
+                    topk_prompt_logprobs_list.append(
+                        [(token_id, logprob.logprob) for token_id, logprob in topk]
+                    )
+        # collect response sequences
+        for seq_output in req_output.outputs:
+            seq = SampledSequence(
+                stop_reason="length" if seq_output.finish_reason == "length" else "stop",
+                tokens=seq_output.token_ids,
+                logprobs=[
+                    next(iter(logprob_dict.values())).logprob
+                    for logprob_dict in seq_output.logprobs
+                ],
+            )
+            sequences.append(seq)
+        return SampleResponse(
+            sequences=sequences,
+            prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
+            topk_prompt_logprobs=(
+                topk_prompt_logprobs_list
+                if include_prompt_logprobs and topk_prompt_logprobs > 0
+                else None
+            ),
+        )
+
     async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
         # Send the request to the LLM engine.
         self.request_id += 1