Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 60 additions & 75 deletions streaming_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
DEFAULT_MAX_NEW_TOKENS = 4000
DEFAULT_REPETITION_PENALTY = 1.1
CODE_TOKEN_OFFSET = 128266
STOP_SEQUENCE = "<custom_token_2>"
AUDIO_SAMPLERATE = 24000
AUDIO_BITS_PER_SAMPLE = 16
AUDIO_CHANNELS = 1
Expand All @@ -36,8 +35,11 @@
STREAM_CHUNK_SIZE_GROUPS = 30
INITIAL_CHUNK_SIZE_GROUPS = 3

CODE_START_TOKEN_ID = 128257
# Official Orpheus-TTS inference wraps the prompt with fixed special-token IDs.
PROMPT_START_TOKEN_ID = 128259
PROMPT_END_TOKEN_IDS = [128009, 128260, 128261, 128257]
CODE_REMOVE_TOKEN_ID = 128258
STOP_TOKEN_IDS = [128258]
app = FastAPI()
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -68,8 +70,8 @@ def format_prompt_for_vllm_sync(prompt_text, voice="in_prompt"):
else:
full_text = prompt_text

start_token = torch.tensor([[128259]], dtype=torch.int64)
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
start_token = torch.tensor([[PROMPT_START_TOKEN_ID]], dtype=torch.int64)
end_tokens = torch.tensor([PROMPT_END_TOKEN_IDS], dtype=torch.int64)
input_ids = tokenizer(full_text, return_tensors="pt").input_ids
modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
decoded_text = tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)
Expand Down Expand Up @@ -196,101 +198,84 @@ async def generate_audio_stream(request: AudioRequest):
max_tokens=request.max_new_tokens,
temperature=request.temperature,
top_p=request.top_p,
stop=[STOP_SEQUENCE],
stream=True,
extra_body={'repetition_penalty': request.repetition_penalty},
extra_body={
'repetition_penalty': request.repetition_penalty,
'stop_token_ids': STOP_TOKEN_IDS,
},
)
response_stream = await client.completions.create(**stream_kwargs)

accumulated_text = ""
processed_code_count = 0
start_token_found = False
start_idx = -1
first_chunk_yielded = False

async for chunk in response_stream:
if chunk.choices:
chunk_text = chunk.choices[0].text or ""
accumulated_text += chunk_text
all_token_ids = await loop.run_in_executor(None, tokenize_sync, accumulated_text)
valid_raw_codes = [
token for token in all_token_ids
if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET
]

if not start_token_found:
try:
start_idx = all_token_ids.index(CODE_START_TOKEN_ID)
start_token_found = True
print(f"Code start token ({CODE_START_TOKEN_ID}) found at index {start_idx}.")
except ValueError:
continue
current_total_codes = len(valid_raw_codes)

if start_token_found:
potential_code_tokens = all_token_ids[start_idx + 1:]
if not first_chunk_yielded:
current_decode_chunk_size = INITIAL_CHUNK_SIZE_GROUPS * 7
print(f"Using initial chunk size: {current_decode_chunk_size} codes")
else:
current_decode_chunk_size = STREAM_CHUNK_SIZE_GROUPS * 7

valid_raw_codes = [
token for token in potential_code_tokens
if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET
]
if current_total_codes >= processed_code_count + current_decode_chunk_size:
codes_to_process_now_count = ((current_total_codes - processed_code_count) // current_decode_chunk_size) * current_decode_chunk_size
end_process_idx = processed_code_count + codes_to_process_now_count

current_total_codes = len(valid_raw_codes)
if end_process_idx > processed_code_count:
codes_to_process_raw = valid_raw_codes[processed_code_count:end_process_idx]
print(f"Processing codes from {processed_code_count} to {end_process_idx} ({len(codes_to_process_raw)} codes)")

if not first_chunk_yielded:
current_decode_chunk_size = INITIAL_CHUNK_SIZE_GROUPS * 7
print(f"Using initial chunk size: {current_decode_chunk_size} codes")
else:
current_decode_chunk_size = STREAM_CHUNK_SIZE_GROUPS * 7
codes_to_process = [t - CODE_TOKEN_OFFSET for t in codes_to_process_raw]

if current_total_codes >= processed_code_count + current_decode_chunk_size:
codes_to_process_now_count = ( (current_total_codes - processed_code_count) // current_decode_chunk_size ) * current_decode_chunk_size
end_process_idx = processed_code_count + codes_to_process_now_count
audio_hat = await loop.run_in_executor(
None, redistribute_codes_sync, codes_to_process
)

if end_process_idx > processed_code_count:
codes_to_process_raw = valid_raw_codes[processed_code_count : end_process_idx]
print(f"Processing codes from {processed_code_count} to {end_process_idx} ({len(codes_to_process_raw)} codes)")
pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) # Apply fade here
if pcm_bytes:
print(f"Yielding {len(pcm_bytes)} bytes of audio data.")
yield pcm_bytes
first_chunk_yielded = True
else:
print("Warning: No PCM bytes generated for this chunk.")

codes_to_process = [t - CODE_TOKEN_OFFSET for t in codes_to_process_raw]

audio_hat = await loop.run_in_executor(
None, redistribute_codes_sync, codes_to_process
)

pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50) # Apply fade here
if pcm_bytes:
print(f"Yielding {len(pcm_bytes)} bytes of audio data.")
yield pcm_bytes
first_chunk_yielded = True
print("Warning: No PCM bytes generated for this chunk.")


processed_code_count = end_process_idx
processed_code_count = end_process_idx

print("Stream finished. Processing remaining codes.")
all_token_ids = await loop.run_in_executor(None, tokenize_sync, accumulated_text)

if start_token_found:
potential_code_tokens = all_token_ids[start_idx + 1:]
valid_raw_codes = [
token for token in potential_code_tokens
if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET
]
current_total_codes = len(valid_raw_codes)

if current_total_codes > processed_code_count:
remaining_codes_raw = valid_raw_codes[processed_code_count:]
num_remaining = len(remaining_codes_raw)
final_len = (num_remaining // 7) * 7

if final_len > 0:
codes_to_process = [t - CODE_TOKEN_OFFSET for t in remaining_codes_raw[:final_len]]
print(f"Processing final {len(codes_to_process)} codes.")

audio_hat = await loop.run_in_executor(
None, redistribute_codes_sync, codes_to_process
)
pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50)
if pcm_bytes:
print(f"Yielding final {len(pcm_bytes)} bytes of audio data.")
yield pcm_bytes
else:
print("Warning: Code start token never found in the entire response.")
valid_raw_codes = [
token for token in all_token_ids
if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET
]
current_total_codes = len(valid_raw_codes)

if current_total_codes > processed_code_count:
remaining_codes_raw = valid_raw_codes[processed_code_count:]
num_remaining = len(remaining_codes_raw)
final_len = (num_remaining // 7) * 7

if final_len > 0:
codes_to_process = [t - CODE_TOKEN_OFFSET for t in remaining_codes_raw[:final_len]]
print(f"Processing final {len(codes_to_process)} codes.")

audio_hat = await loop.run_in_executor(
None, redistribute_codes_sync, codes_to_process
)
pcm_bytes = convert_to_pcm16_bytes(audio_hat, fade_ms=50)
if pcm_bytes:
print(f"Yielding final {len(pcm_bytes)} bytes of audio data.")
yield pcm_bytes


print("Audio stream generation complete.")
Expand Down