Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions astrbot/core/provider/sources/openai_tts_api_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,98 @@ def __init__(

self.set_model(provider_config.get("model", ""))

@staticmethod
def _looks_like_text_payload(audio_bytes: bytes) -> bool:
sample = audio_bytes[:128].lstrip()
if not sample:
return False
if sample.startswith((b"{", b"[", b"<")):
return True
text_like = sum(1 for byte in sample if byte in b"\t\n\r" or 32 <= byte <= 126)
return text_like / len(sample) > 0.95

@classmethod
def _resolve_audio_extension(cls, content_type: str | None, audio_bytes: bytes) -> str:
normalized = (content_type or "").split(";", 1)[0].strip().lower()
extension_map = {
"audio/wav": ".wav",
"audio/wave": ".wav",
"audio/x-wav": ".wav",
"audio/mpeg": ".mp3",
"audio/mp3": ".mp3",
"audio/x-mpeg": ".mp3",
"audio/ogg": ".ogg",
"audio/opus": ".ogg",
"audio/flac": ".flac",
"audio/x-flac": ".flac",
"audio/aac": ".aac",
"audio/x-aac": ".aac",
"audio/webm": ".webm",
}

if normalized:
if not normalized.startswith("audio/"):
preview = audio_bytes[:200].decode("utf-8", errors="ignore").strip()
preview = preview or "<empty response>"
raise RuntimeError(
f"[OpenAI TTS] unexpected content-type {normalized!r} from TTS endpoint: {preview[:200]}"
)
if normalized in extension_map:
return extension_map[normalized]

header = audio_bytes[:16]
if header.startswith(b"RIFF") and audio_bytes[8:12] == b"WAVE":
return ".wav"
if header.startswith(b"ID3") or (
len(audio_bytes) >= 2
and audio_bytes[0] == 0xFF
and (audio_bytes[1] & 0xE0) == 0xE0
):
return ".mp3"
if header.startswith(b"OggS"):
return ".ogg"
if header.startswith(b"fLaC"):
return ".flac"
if header.startswith(b"\x1aE\xdf\xa3"):
return ".webm"
if header.startswith((b"\xff\xf1", b"\xff\xf9")):
return ".aac"

if cls._looks_like_text_payload(audio_bytes):
preview = audio_bytes[:200].decode("utf-8", errors="ignore").strip()
preview = preview or "<empty response>"
raise RuntimeError(
f"[OpenAI TTS] TTS endpoint returned a non-audio payload: {preview[:200]}"
)

return ".wav"

async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
os.makedirs(temp_dir, exist_ok=True)
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format="wav",
input=text,
) as response:
with open(path, "wb") as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
chunks = []
async for chunk in response.iter_bytes(chunk_size=1024):
if chunk:
chunks.append(chunk)

if not chunks:
raise RuntimeError("[OpenAI TTS] empty audio response")

audio_bytes = b"".join(chunks)
content_type = None
if getattr(response, "headers", None):
content_type = response.headers.get("content-type")

ext = self._resolve_audio_extension(content_type, audio_bytes)
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}{ext}")
with open(path, "wb") as f:
f.write(audio_bytes)
return path

async def terminate(self):
Expand Down
96 changes: 96 additions & 0 deletions tests/test_openai_tts_api_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import asyncio
from pathlib import Path

import pytest

from astrbot.core.provider.sources import openai_tts_api_source
from astrbot.core.provider.sources.openai_tts_api_source import ProviderOpenAITTSAPI


class FakeStreamingResponse:
def __init__(self, chunks: list[bytes], headers: dict[str, str] | None = None):
self._chunks = chunks
self.headers = headers or {}

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False

async def iter_bytes(self, chunk_size: int = 1024):
for chunk in self._chunks:
yield chunk


class FakeStreamingSpeech:
def __init__(self, response: FakeStreamingResponse):
self.response = response
self.calls: list[dict] = []

def create(self, **kwargs):
self.calls.append(kwargs)
return self.response


class FakeClient:
def __init__(self, response: FakeStreamingResponse):
self.audio = type(
"FakeAudio",
(),
{
"speech": type(
"FakeSpeech",
(),
{"with_streaming_response": FakeStreamingSpeech(response)},
)()
},
)()
self.closed = False

async def close(self):
self.closed = True


def make_provider(monkeypatch, response: FakeStreamingResponse) -> ProviderOpenAITTSAPI:
fake_client = FakeClient(response)
monkeypatch.setattr(openai_tts_api_source, "AsyncOpenAI", lambda **kwargs: fake_client)
provider = ProviderOpenAITTSAPI(
{
"id": "openai_tts",
"type": "openai_tts_api",
"model": "gpt-4o-mini-tts",
"api_key": "test-key",
"openai-tts-voice": "alloy",
},
{},
)
provider.client = fake_client
return provider


def test_get_audio_preserves_real_audio_extension(monkeypatch, tmp_path: Path):
response = FakeStreamingResponse(
chunks=[b"ID3", b"fake-mp3-audio"],
headers={"content-type": "audio/mpeg"},
)
provider = make_provider(monkeypatch, response)
monkeypatch.setattr(openai_tts_api_source, "get_astrbot_temp_path", lambda: str(tmp_path))

path = asyncio.run(provider.get_audio("hello"))

assert path.endswith(".mp3")
assert Path(path).read_bytes() == b"ID3fake-mp3-audio"
assert provider.client.audio.speech.with_streaming_response.calls[0]["response_format"] == "wav"


def test_get_audio_raises_clear_error_for_non_audio_payload(monkeypatch, tmp_path: Path):
response = FakeStreamingResponse(
chunks=[b'{"error":"unsupported response_format"}'],
headers={"content-type": "application/json"},
)
provider = make_provider(monkeypatch, response)
monkeypatch.setattr(openai_tts_api_source, "get_astrbot_temp_path", lambda: str(tmp_path))

with pytest.raises(RuntimeError, match="unexpected content-type"):
asyncio.run(provider.get_audio("hello"))