diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 217b18925..bcdffb45b 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -45,18 +45,98 @@ def __init__( self.set_model(provider_config.get("model", "")) + @staticmethod + def _looks_like_text_payload(audio_bytes: bytes) -> bool: + sample = audio_bytes[:128].lstrip() + if not sample: + return False + if sample.startswith((b"{", b"[", b"<")): + return True + text_like = sum(1 for byte in sample if byte in b"\t\n\r" or 32 <= byte <= 126) + return text_like / len(sample) > 0.95 + + @classmethod + def _resolve_audio_extension(cls, content_type: str | None, audio_bytes: bytes) -> str: + normalized = (content_type or "").split(";", 1)[0].strip().lower() + extension_map = { + "audio/wav": ".wav", + "audio/wave": ".wav", + "audio/x-wav": ".wav", + "audio/mpeg": ".mp3", + "audio/mp3": ".mp3", + "audio/x-mpeg": ".mp3", + "audio/ogg": ".ogg", + "audio/opus": ".ogg", + "audio/flac": ".flac", + "audio/x-flac": ".flac", + "audio/aac": ".aac", + "audio/x-aac": ".aac", + "audio/webm": ".webm", + } + + if normalized: + if not normalized.startswith("audio/"): + preview = audio_bytes[:200].decode("utf-8", errors="ignore").strip() + preview = preview or "" + raise RuntimeError( + f"[OpenAI TTS] unexpected content-type {normalized!r} from TTS endpoint: {preview[:200]}" + ) + if normalized in extension_map: + return extension_map[normalized] + + header = audio_bytes[:16] + if header.startswith(b"RIFF") and audio_bytes[8:12] == b"WAVE": + return ".wav" + if header.startswith(b"ID3") or ( + len(audio_bytes) >= 2 + and audio_bytes[0] == 0xFF + and (audio_bytes[1] & 0xE0) == 0xE0 + ): + return ".mp3" + if header.startswith(b"OggS"): + return ".ogg" + if header.startswith(b"fLaC"): + return ".flac" + if header.startswith(b"\x1aE\xdf\xa3"): + return ".webm" + if 
header.startswith((b"\xff\xf1", b"\xff\xf9")): + return ".aac" + + if cls._looks_like_text_payload(audio_bytes): + preview = audio_bytes[:200].decode("utf-8", errors="ignore").strip() + preview = preview or "" + raise RuntimeError( + f"[OpenAI TTS] TTS endpoint returned a non-audio payload: {preview[:200]}" + ) + + return ".wav" + async def get_audio(self, text: str) -> str: temp_dir = get_astrbot_temp_path() - path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav") + os.makedirs(temp_dir, exist_ok=True) async with self.client.audio.speech.with_streaming_response.create( model=self.model_name, voice=self.voice, response_format="wav", input=text, ) as response: - with open(path, "wb") as f: - async for chunk in response.iter_bytes(chunk_size=1024): - f.write(chunk) + chunks = [] + async for chunk in response.iter_bytes(chunk_size=1024): + if chunk: + chunks.append(chunk) + + if not chunks: + raise RuntimeError("[OpenAI TTS] empty audio response") + + audio_bytes = b"".join(chunks) + content_type = None + if getattr(response, "headers", None): + content_type = response.headers.get("content-type") + + ext = self._resolve_audio_extension(content_type, audio_bytes) + path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}{ext}") + with open(path, "wb") as f: + f.write(audio_bytes) return path async def terminate(self): diff --git a/tests/test_openai_tts_api_source.py b/tests/test_openai_tts_api_source.py new file mode 100644 index 000000000..e18639366 --- /dev/null +++ b/tests/test_openai_tts_api_source.py @@ -0,0 +1,96 @@ +import asyncio +from pathlib import Path + +import pytest + +from astrbot.core.provider.sources import openai_tts_api_source +from astrbot.core.provider.sources.openai_tts_api_source import ProviderOpenAITTSAPI + + +class FakeStreamingResponse: + def __init__(self, chunks: list[bytes], headers: dict[str, str] | None = None): + self._chunks = chunks + self.headers = headers or {} + + async def __aenter__(self): + return self + + 
"""Tests for ProviderOpenAITTSAPI.get_audio driven by in-memory client fakes."""

import asyncio
from pathlib import Path

import pytest

from astrbot.core.provider.sources import openai_tts_api_source
from astrbot.core.provider.sources.openai_tts_api_source import ProviderOpenAITTSAPI


class FakeStreamingResponse:
    """Async context manager that replays fixed byte chunks and headers."""

    def __init__(self, chunks: list[bytes], headers: dict[str, str] | None = None):
        self._chunks = chunks
        self.headers = headers if headers else {}

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False lets any exception raised in the body propagate.
        return False

    async def iter_bytes(self, chunk_size: int = 1024):
        # chunk_size is accepted for signature compatibility and ignored.
        for piece in self._chunks:
            yield piece


class FakeStreamingSpeech:
    """Records every ``create`` call and hands back the canned response."""

    def __init__(self, response: FakeStreamingResponse):
        self.response = response
        self.calls: list[dict] = []

    def create(self, **kwargs):
        self.calls.append(kwargs)
        return self.response


class FakeClient:
    """Exposes ``audio.speech.with_streaming_response`` plus ``close``."""

    def __init__(self, response: FakeStreamingResponse):
        streaming = FakeStreamingSpeech(response)
        speech_type = type("FakeSpeech", (), {"with_streaming_response": streaming})
        audio_type = type("FakeAudio", (), {"speech": speech_type()})
        self.audio = audio_type()
        self.closed = False

    async def close(self):
        self.closed = True


def make_provider(monkeypatch, response: FakeStreamingResponse) -> ProviderOpenAITTSAPI:
    """Build a provider whose OpenAI client is replaced by ``FakeClient``."""
    stub_client = FakeClient(response)
    monkeypatch.setattr(openai_tts_api_source, "AsyncOpenAI", lambda **kwargs: stub_client)
    config = {
        "id": "openai_tts",
        "type": "openai_tts_api",
        "model": "gpt-4o-mini-tts",
        "api_key": "test-key",
        "openai-tts-voice": "alloy",
    }
    provider = ProviderOpenAITTSAPI(config, {})
    provider.client = stub_client
    return provider


def test_get_audio_preserves_real_audio_extension(monkeypatch, tmp_path: Path):
    """An MP3 body is saved with ``.mp3`` even though WAV was requested."""
    provider = make_provider(
        monkeypatch,
        FakeStreamingResponse(
            chunks=[b"ID3", b"fake-mp3-audio"],
            headers={"content-type": "audio/mpeg"},
        ),
    )
    monkeypatch.setattr(openai_tts_api_source, "get_astrbot_temp_path", lambda: str(tmp_path))

    path = asyncio.run(provider.get_audio("hello"))

    recorded_calls = provider.client.audio.speech.with_streaming_response.calls
    assert path.endswith(".mp3")
    assert Path(path).read_bytes() == b"ID3fake-mp3-audio"
    assert recorded_calls[0]["response_format"] == "wav"


def test_get_audio_raises_clear_error_for_non_audio_payload(monkeypatch, tmp_path: Path):
    """A JSON error body with a non-audio content type raises RuntimeError."""
    provider = make_provider(
        monkeypatch,
        FakeStreamingResponse(
            chunks=[b'{"error":"unsupported response_format"}'],
            headers={"content-type": "application/json"},
        ),
    )
    monkeypatch.setattr(openai_tts_api_source, "get_astrbot_temp_path", lambda: str(tmp_path))

    with pytest.raises(RuntimeError, match="unexpected content-type"):
        asyncio.run(provider.get_audio("hello"))