From ab8a5be743063ce05a98f7aa4caa2a45f358cd07 Mon Sep 17 00:00:00 2001 From: bodza Date: Mon, 30 Mar 2026 21:59:27 +0200 Subject: [PATCH] Align FastAPI with official Orpheus prompt setup --- streaming_api_server.py | 135 ++++++++++++++++++---------------------- 1 file changed, 60 insertions(+), 75 deletions(-) diff --git a/streaming_api_server.py b/streaming_api_server.py index f229257..0b040b7 100644 --- a/streaming_api_server.py +++ b/streaming_api_server.py @@ -27,7 +27,6 @@ DEFAULT_MAX_NEW_TOKENS = 4000 DEFAULT_REPETITION_PENALTY = 1.1 CODE_TOKEN_OFFSET = 128266 -STOP_SEQUENCE = "" AUDIO_SAMPLERATE = 24000 AUDIO_BITS_PER_SAMPLE = 16 AUDIO_CHANNELS = 1 @@ -36,8 +35,11 @@ STREAM_CHUNK_SIZE_GROUPS = 30 INITIAL_CHUNK_SIZE_GROUPS = 3 -CODE_START_TOKEN_ID = 128257 +# Official Orpheus-TTS inference wraps the prompt with fixed special-token IDs. +PROMPT_START_TOKEN_ID = 128259 +PROMPT_END_TOKEN_IDS = [128009, 128260, 128261, 128257] CODE_REMOVE_TOKEN_ID = 128258 +STOP_TOKEN_IDS = [128258] app = FastAPI() SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -68,8 +70,8 @@ def format_prompt_for_vllm_sync(prompt_text, voice="in_prompt"): else: full_text = prompt_text - start_token = torch.tensor([[128259]], dtype=torch.int64) - end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) + start_token = torch.tensor([[PROMPT_START_TOKEN_ID]], dtype=torch.int64) + end_tokens = torch.tensor([PROMPT_END_TOKEN_IDS], dtype=torch.int64) input_ids = tokenizer(full_text, return_tensors="pt").input_ids modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) decoded_text = tokenizer.decode(modified_input_ids[0], skip_special_tokens=False) @@ -196,16 +198,16 @@ async def generate_audio_stream(request: AudioRequest): max_tokens=request.max_new_tokens, temperature=request.temperature, top_p=request.top_p, - stop=[STOP_SEQUENCE], stream=True, - extra_body={'repetition_penalty': request.repetition_penalty}, + extra_body={ + 'repetition_penalty': request.repetition_penalty, + 'stop_token_ids': STOP_TOKEN_IDS, + }, ) response_stream = await client.completions.create(**stream_kwargs) accumulated_text = "" processed_code_count = 0 - start_token_found = False - start_idx = -1 first_chunk_yielded = False async for chunk in response_stream: @@ -213,84 +215,67 @@ async def generate_audio_stream(request: AudioRequest): chunk_text = chunk.choices[0].text or "" accumulated_text += chunk_text all_token_ids = await loop.run_in_executor(None, tokenize_sync, accumulated_text) + valid_raw_codes = [ + token for token in all_token_ids + if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET + ] - if not start_token_found: - try: - start_idx = all_token_ids.index(CODE_START_TOKEN_ID) - start_token_found = True - print(f"Code start token ({CODE_START_TOKEN_ID}) found at index {start_idx}.") - except ValueError: - continue + current_total_codes = len(valid_raw_codes) - if start_token_found: - potential_code_tokens = all_token_ids[start_idx + 1:] + if not first_chunk_yielded: + current_decode_chunk_size = INITIAL_CHUNK_SIZE_GROUPS * 7 + print(f"Using initial chunk size: {current_decode_chunk_size} codes") + else: + current_decode_chunk_size = STREAM_CHUNK_SIZE_GROUPS * 7 - valid_raw_codes = [ - token for token in potential_code_tokens - if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET - ] + if current_total_codes >= processed_code_count + current_decode_chunk_size: + codes_to_process_now_count = ((current_total_codes - processed_code_count) // current_decode_chunk_size) * current_decode_chunk_size + end_process_idx = processed_code_count + codes_to_process_now_count - current_total_codes = len(valid_raw_codes) + if end_process_idx > processed_code_count: + codes_to_process_raw = valid_raw_codes[processed_code_count:end_process_idx] + print(f"Processing codes from {processed_code_count} to {end_process_idx} ({len(codes_to_process_raw)} codes)") - if not first_chunk_yielded: - current_decode_chunk_size = INITIAL_CHUNK_SIZE_GROUPS * 7 - print(f"Using initial chunk size: {current_decode_chunk_size} codes") - else: - current_decode_chunk_size = STREAM_CHUNK_SIZE_GROUPS * 7 + codes_to_process = [t - CODE_TOKEN_OFFSET for t in codes_to_process_raw] - if current_total_codes >= processed_code_count + current_decode_chunk_size: - codes_to_process_now_count = ( (current_total_codes - processed_code_count) // current_decode_chunk_size ) * current_decode_chunk_size - end_process_idx = processed_code_count + codes_to_process_now_count + audio_hat = await loop.run_in_executor( + None, redistribute_codes_sync, codes_to_process + ) - if end_process_idx > processed_code_count: - codes_to_process_raw = valid_raw_codes[processed_code_count : end_process_idx] - print(f"Processing codes from {processed_code_count} to {end_process_idx} ({len(codes_to_process_raw)} codes)") + pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) # Apply fade here + if pcm_bytes: + print(f"Yielding {len(pcm_bytes)} bytes of audio data.") + yield pcm_bytes + first_chunk_yielded = True + else: + print("Warning: No PCM bytes generated for this chunk.") - codes_to_process = [t - CODE_TOKEN_OFFSET for t in codes_to_process_raw] - - audio_hat = await loop.run_in_executor( - None, redistribute_codes_sync, codes_to_process - ) - - pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) # Apply fade here - if pcm_bytes: - print(f"Yielding {len(pcm_bytes)} bytes of audio data.") - yield pcm_bytes - first_chunk_yielded = True - print("Warning: No PCM bytes generated for this chunk.") - - - processed_code_count = end_process_idx + processed_code_count = end_process_idx print("Stream finished. Processing remaining codes.") all_token_ids = await loop.run_in_executor(None, tokenize_sync, accumulated_text) - - if start_token_found: - potential_code_tokens = all_token_ids[start_idx + 1:] - valid_raw_codes = [ - token for token in potential_code_tokens - if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET - ] - current_total_codes = len(valid_raw_codes) - - if current_total_codes > processed_code_count: - remaining_codes_raw = valid_raw_codes[processed_code_count:] - num_remaining = len(remaining_codes_raw) - final_len = (num_remaining // 7) * 7 - - if final_len > 0: - codes_to_process = [t - CODE_TOKEN_OFFSET for t in remaining_codes_raw[:final_len]] - print(f"Processing final {len(codes_to_process)} codes.") - - audio_hat = await loop.run_in_executor( - None, redistribute_codes_sync, codes_to_process - ) - pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) - if pcm_bytes: - print(f"Yielding final {len(pcm_bytes)} bytes of audio data.") - yield pcm_bytes - else: - print("Warning: Code start token never found in the entire response.") + valid_raw_codes = [ + token for token in all_token_ids + if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET + ] + current_total_codes = len(valid_raw_codes) + + if current_total_codes > processed_code_count: + remaining_codes_raw = valid_raw_codes[processed_code_count:] + num_remaining = len(remaining_codes_raw) + final_len = (num_remaining // 7) * 7 + + if final_len > 0: + codes_to_process = [t - CODE_TOKEN_OFFSET for t in remaining_codes_raw[:final_len]] + print(f"Processing final {len(codes_to_process)} codes.") + + audio_hat = await loop.run_in_executor( + None, redistribute_codes_sync, codes_to_process + ) + pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) + if pcm_bytes: + print(f"Yielding final {len(pcm_bytes)} bytes of audio data.") + yield pcm_bytes print("Audio stream generation complete.")