From 2e792cbae383c3b9558f1b117378bcdba9c0a74e Mon Sep 17 00:00:00 2001
From: pan-x-c
Date: Tue, 23 Dec 2025 12:33:59 +0000
Subject: [PATCH 1/4] add tinker sample interface

---
 pyproject.toml                      |  1 +
 tests/common/vllm_test.py           | 48 +++++++++++++++++++++
 trinity/common/models/vllm_model.py | 66 ++++++++++++++++++++++++++++-
 3 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index f7a8162bfe..ca1ba202a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
     "sortedcontainers",
     "word2number",
     "transformers",
+    "tinker",
 ]
 
 [project.scripts]
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 61fd03e675..4d38531d74 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1219,3 +1219,51 @@ async def test_generate(self):
             response.prompt_length, 40960
         )  # If not long enough, please add more files to prompt
         self.assertGreater(response.logprobs.shape[0], 1000)
+
+
+class TestTinkerAPI(RayUnittestBaseAysnc):
+    """Test the Tinker API integration with the vLLM engine."""
+
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_tinker_api(self):
+        from tinker import types
+        from transformers import AutoTokenizer
+
+        engine = self.engines[0]
+        tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
+        prompt = types.ModelInput.from_ints(
+            tokenizer.encode("How many r's are in the word strawberry?"),
+        )
+        num_samples = 2
+        topk_prompt_logprobs = 3
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=16),
+            include_prompt_logprobs=True,
+            topk_prompt_logprobs=topk_prompt_logprobs,
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "stop")
+        self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.prompt_logprobs[0])
+        self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.topk_prompt_logprobs[0])
+        for topk_logprobs in response.topk_prompt_logprobs[1:]:
+            self.assertIsNotNone(topk_logprobs)
+            self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c6f85b48f8..1443dfc1b0 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -3,13 +3,14 @@
 import asyncio
 import os
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 import ray
 import torch
 from packaging.version import parse as parse_version
 from PIL import Image
+from tinker import types
 from transformers import AutoProcessor
 
 from trinity.common.config import InferenceModelConfig
@@ -402,6 +403,69 @@ async def logprobs(  # type: ignore [override]
             dtype=torch.float32,
         )
 
+    async def sample(
+        self,
+        prompt: types.ModelInput,
+        num_samples: int,
+        sampling_params: types.SamplingParams,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_request=None,
+    ) -> types.SampleResponse:
+        """Tinker compatible sampling interface."""
+        params = {
+            "max_tokens": sampling_params.max_tokens,
+            "seed": sampling_params.seed,
+            "stop": sampling_params.stop,
+            "top_k": sampling_params.top_k,
+            "top_p": sampling_params.top_p,
+            "temperature": sampling_params.temperature,
+            "n": num_samples,
+            "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
+            # in vLLM, 0 means only return the chosen token's logprob
+            "logprobs": 0,
+        }
+        req_output = await self._generate_internal(
+            prompt={"prompt_token_ids": prompt.to_ints()},
+            lora_request=lora_request,
+            **params,
+        )
+        sequences = []
+        topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
+        prompt_logprobs: List[Optional[float]] = [None]
+
+        # collect prompt logprobs
+        for logprob_dict in req_output.prompt_logprobs[1:]:
+            prompt_logprobs.append(list(logprob_dict.values())[0].logprob)
+            if topk_prompt_logprobs > 0:
+                # collect top-k prompt logprobs
+                # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
+                logprob_items = list(logprob_dict.items())
+                # sort by Logprob.rank
+                logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
+                # pick topk
+                topk = logprob_items_sorted[:topk_prompt_logprobs]
+                # record as (token_id, logprob)
+                topk_prompt_logprobs_list.append(
+                    [(token_id, logprob.logprob) for token_id, logprob in topk]
+                )
+        # collect response sequences
+        for seq_output in req_output.outputs:
+            seq = types.SampledSequence(
+                stop_reason="length" if seq_output.stop_reason == "length" else "stop",
+                tokens=seq_output.token_ids,
+                logprobs=[
+                    list(logprob_dict.values())[0].logprob for logprob_dict in seq_output.logprobs
+                ],
+            )
+            sequences.append(seq)
+
+        return types.SampleResponse(
+            sequences=sequences,
+            prompt_logprobs=prompt_logprobs,
+            topk_prompt_logprobs=topk_prompt_logprobs_list,
+        )
+
     async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
         # Send the request to the LLM engine.
         self.request_id += 1

From 8f387adeb106d426f74ea3bc7ead3ead85cc7861 Mon Sep 17 00:00:00 2001
From: pan-x-c
Date: Tue, 23 Dec 2025 13:16:19 +0000
Subject: [PATCH 2/4] add more tests

---
 tests/common/vllm_test.py           | 47 ++++++++++++++++++++++--
 trinity/common/models/vllm_model.py | 55 ++++++++++++++++------------
 2 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 4d38531d74..b731e51ea2 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1244,22 +1244,52 @@ async def test_tinker_api(self):
 
         engine = self.engines[0]
         tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        result_dict = tokenizer.apply_chat_template(
+            messages,
+            chat_template=CHAT_TEMPLATE,
+            add_generation_prompt=False,
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            add_special_tokens=False,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+        )
         prompt = types.ModelInput.from_ints(
-            tokenizer.encode("How many r's are in the word strawberry?"),
+            result_dict["input_ids"][0].tolist(),
         )
+        # sample api without prompt logprobs
+        num_samples = 4
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7),  # no limit on length
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            print("response length:", len(sequence.tokens))
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "stop")
+        self.assertIsNone(response.prompt_logprobs)
+        self.assertIsNone(response.topk_prompt_logprobs)
+        # sample api with prompt logprobs
         num_samples = 2
         topk_prompt_logprobs = 3
         response = await engine.sample.remote(
             prompt=prompt,
             num_samples=num_samples,
-            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=16),
+            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
             include_prompt_logprobs=True,
             topk_prompt_logprobs=topk_prompt_logprobs,
         )
         self.assertEqual(len(response.sequences), num_samples)
         for sequence in response.sequences:
             self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
-            self.assertEqual(sequence.stop_reason, "stop")
+            self.assertEqual(sequence.stop_reason, "length")
         self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
         self.assertIsNone(response.prompt_logprobs[0])
         self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
@@ -1267,3 +1297,14 @@ async def test_tinker_api(self):
         for topk_logprobs in response.topk_prompt_logprobs[1:]:
             self.assertIsNotNone(topk_logprobs)
             self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
+        # compute_logprob api
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 1443dfc1b0..146b034ff6 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -414,17 +414,18 @@ async def sample(
     ) -> types.SampleResponse:
         """Tinker compatible sampling interface."""
         params = {
-            "max_tokens": sampling_params.max_tokens,
-            "seed": sampling_params.seed,
-            "stop": sampling_params.stop,
-            "top_k": sampling_params.top_k,
-            "top_p": sampling_params.top_p,
-            "temperature": sampling_params.temperature,
+            "max_tokens": sampling_params.max_tokens or self.config.max_response_tokens,
+            "seed": sampling_params.seed or self.config.seed,
+            "top_k": sampling_params.top_k or self.config.top_k,
+            "top_p": sampling_params.top_p or self.config.top_p,
+            "temperature": sampling_params.temperature or self.config.temperature,
             "n": num_samples,
             "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
             # in vLLM, 0 means only return the chosen token's logprob
             "logprobs": 0,
         }
+        if sampling_params.stop is not None:
+            params["stop"] = sampling_params.stop
         req_output = await self._generate_internal(
             prompt={"prompt_token_ids": prompt.to_ints()},
             lora_request=lora_request,
@@ -435,35 +436,39 @@ async def sample(
         prompt_logprobs: List[Optional[float]] = [None]
 
         # collect prompt logprobs
-        for logprob_dict in req_output.prompt_logprobs[1:]:
-            prompt_logprobs.append(list(logprob_dict.values())[0].logprob)
-            if topk_prompt_logprobs > 0:
-                # collect top-k prompt logprobs
-                # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
-                logprob_items = list(logprob_dict.items())
-                # sort by Logprob.rank
-                logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
-                # pick topk
-                topk = logprob_items_sorted[:topk_prompt_logprobs]
-                # record as (token_id, logprob)
-                topk_prompt_logprobs_list.append(
-                    [(token_id, logprob.logprob) for token_id, logprob in topk]
-                )
+        if include_prompt_logprobs:
+            for logprob_dict in req_output.prompt_logprobs[1:]:
+                prompt_logprobs.append(list(logprob_dict.values())[0].logprob)
+                if topk_prompt_logprobs > 0:
+                    # collect top-k prompt logprobs
+                    # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
+                    logprob_items = list(logprob_dict.items())
+                    # sort by Logprob.rank
+                    logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
+                    # pick topk
+                    topk = logprob_items_sorted[:topk_prompt_logprobs]
+                    # record as (token_id, logprob)
+                    topk_prompt_logprobs_list.append(
+                        [(token_id, logprob.logprob) for token_id, logprob in topk]
+                    )
         # collect response sequences
         for seq_output in req_output.outputs:
             seq = types.SampledSequence(
-                stop_reason="length" if seq_output.stop_reason == "length" else "stop",
+                stop_reason="length" if seq_output.finish_reason == "length" else "stop",
                 tokens=seq_output.token_ids,
                 logprobs=[
                     list(logprob_dict.values())[0].logprob for logprob_dict in seq_output.logprobs
                 ],
             )
             sequences.append(seq)
-
         return types.SampleResponse(
             sequences=sequences,
-            prompt_logprobs=prompt_logprobs,
-            topk_prompt_logprobs=topk_prompt_logprobs_list,
+            prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
+            topk_prompt_logprobs=(
+                topk_prompt_logprobs_list
+                if include_prompt_logprobs and topk_prompt_logprobs > 0
+                else None
+            ),
         )
 
     async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
@@ -511,7 +516,7 @@ async def convert_messages_to_experience(
         if len(token_ids) > self.config.max_model_len - 1:
             truncate_status = "response_truncated"
             self.logger.warning(
-                f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }"
+                f"Warning: {len(token_ids)=} exceeds the length limit {self.config.max_model_len - 1=}"
             )
             token_ids = token_ids[: self.config.max_model_len - 1]
             action_mask = action_mask[: self.config.max_model_len - 1]

From f70318b0549f787bf5dbd1ffba3022d8bd6c83bd Mon Sep 17 00:00:00 2001
From: pan-x-c
Date: Thu, 25 Dec 2025 11:25:28 +0000
Subject: [PATCH 3/4] fix comments

---
 tests/common/vllm_test.py           |  1 -
 trinity/common/models/vllm_model.py | 21 +++++++++++++--------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index b731e51ea2..7b545d52f3 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1271,7 +1271,6 @@ async def test_tinker_api(self):
         )
         self.assertEqual(len(response.sequences), num_samples)
         for sequence in response.sequences:
-            print("response length:", len(sequence.tokens))
             self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
             self.assertEqual(sequence.stop_reason, "stop")
         self.assertIsNone(response.prompt_logprobs)
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 2630cb2b0d..1603562cf4 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -410,15 +410,17 @@ async def sample(
         sampling_params: types.SamplingParams,
         include_prompt_logprobs: bool = False,
         topk_prompt_logprobs: int = 0,
-        lora_request=None,
+        lora_request: Optional[Any] = None,
     ) -> types.SampleResponse:
         """Tinker compatible sampling interface."""
         params = {
-            "max_tokens": sampling_params.max_tokens or self.config.max_response_tokens,
-            "seed": sampling_params.seed or self.config.seed,
-            "top_k": sampling_params.top_k or self.config.top_k,
-            "top_p": sampling_params.top_p or self.config.top_p,
-            "temperature": sampling_params.temperature or self.config.temperature,
+            "max_tokens": sampling_params.max_tokens
+            if sampling_params.max_tokens is not None
+            else self.config.max_response_tokens,
+            "seed": sampling_params.seed if sampling_params.seed is not None else self.config.seed,
+            "top_k": sampling_params.top_k,
+            "top_p": sampling_params.top_p,
+            "temperature": sampling_params.temperature,
             "n": num_samples,
             "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
             # in vLLM, 0 means only return the chosen token's logprob
@@ -432,13 +434,15 @@ async def sample(
             **params,
         )
         sequences = []
+        # vLLM's prompt_logprobs output does not include a value for the first token.
+        # Initialize with [None] to align with the prompt tokens.
         topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
         prompt_logprobs: List[Optional[float]] = [None]
 
         # collect prompt logprobs
         if include_prompt_logprobs:
             for logprob_dict in req_output.prompt_logprobs[1:]:
-                prompt_logprobs.append(list(logprob_dict.values())[0].logprob)
+                prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
                 if topk_prompt_logprobs > 0:
                     # collect top-k prompt logprobs
                     # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
@@ -457,7 +461,8 @@ async def sample(
                 stop_reason="length" if seq_output.finish_reason == "length" else "stop",
                 tokens=seq_output.token_ids,
                 logprobs=[
-                    list(logprob_dict.values())[0].logprob for logprob_dict in seq_output.logprobs
+                    next(iter(logprob_dict.values())).logprob
+                    for logprob_dict in seq_output.logprobs
                 ],
             )
             sequences.append(seq)

From 87cb9e56a2225cf98147279eb3bc451b386821f4 Mon Sep 17 00:00:00 2001
From: pan-x-c
Date: Thu, 25 Dec 2025 11:49:22 +0000
Subject: [PATCH 4/4] fix import

---
 trinity/common/models/vllm_model.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 1603562cf4..43e7c852ac 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -10,7 +10,6 @@
 import torch
 from packaging.version import parse as parse_version
 from PIL import Image
-from tinker import types
 from transformers import AutoProcessor
 
 from trinity.common.config import InferenceModelConfig
@@ -405,14 +404,27 @@ async def logprobs(  # type: ignore [override]
 
     async def sample(
         self,
-        prompt: types.ModelInput,
+        prompt: Any,
         num_samples: int,
-        sampling_params: types.SamplingParams,
+        sampling_params: Any,
         include_prompt_logprobs: bool = False,
         topk_prompt_logprobs: int = 0,
         lora_request: Optional[Any] = None,
-    ) -> types.SampleResponse:
-        """Tinker compatible sampling interface."""
+    ) -> Any:
+        """Tinker compatible sampling interface.
+
+        Args:
+            prompt (ModelInput): The input prompt.
+            num_samples (int): The number of samples to generate.
+            sampling_params (SamplingParams): The sampling parameters.
+            include_prompt_logprobs (bool): Whether to include prompt logprobs.
+            topk_prompt_logprobs (int): The top-k prompt logprobs to include.
+            lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
+        Returns:
+            SampleResponse: The sample response.
+        """
+        from tinker.types import SampledSequence, SampleResponse
+
         params = {
             "max_tokens": sampling_params.max_tokens
             if sampling_params.max_tokens is not None
@@ -457,7 +469,7 @@ async def sample(
         )
         # collect response sequences
         for seq_output in req_output.outputs:
-            seq = types.SampledSequence(
+            seq = SampledSequence(
                 stop_reason="length" if seq_output.finish_reason == "length" else "stop",
                 tokens=seq_output.token_ids,
                 logprobs=[
@@ -466,7 +478,7 @@ async def sample(
                 ],
             )
             sequences.append(seq)
-        return types.SampleResponse(
+        return SampleResponse(
             sequences=sequences,
             prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
             topk_prompt_logprobs=(