diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 386da063d..16b454b2b 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -6,6 +6,7 @@ from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_file +from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.tencent_record_helper import ( convert_to_pcm_wav, tencent_silk_to_wav, @@ -76,7 +77,18 @@ async def get_text(self, audio_url: str) -> str: if not os.path.exists(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") - if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: + lower_audio_url = audio_url.lower() + + if lower_audio_url.endswith(".opus"): + temp_dir = get_astrbot_temp_path() + output_path = os.path.join( + temp_dir, + f"whisper_api_{uuid.uuid4().hex[:8]}.wav", + ) + logger.info("Converting opus file to wav using convert_audio_to_wav...") + await convert_audio_to_wav(audio_url, output_path) + audio_url = output_path + elif lower_audio_url.endswith(".amr") or lower_audio_url.endswith(".silk") or is_tencent: file_format = await self._get_audio_format(audio_url) # 判断是否需要转换 diff --git a/tests/test_whisper_api_source.py b/tests/test_whisper_api_source.py new file mode 100644 index 000000000..dd4a456a3 --- /dev/null +++ b/tests/test_whisper_api_source.py @@ -0,0 +1,72 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI + + +def _make_provider() -> ProviderOpenAIWhisperAPI: + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + }, + provider_settings={}, + ) + provider.client = SimpleNamespace( + audio=SimpleNamespace( + transcriptions=SimpleNamespace( + create=AsyncMock(return_value=SimpleNamespace(text="transcribed text")) + ) + ), + close=AsyncMock(), + ) + return provider + + +@pytest.mark.asyncio +async def test_get_text_converts_opus_files_to_wav_before_transcription( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + provider = _make_provider() + opus_path = tmp_path / "voice.opus" + opus_path.write_bytes(b"fake opus data") + + conversions: list[tuple[str, str]] = [] + + async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None): + assert output_path is not None + conversions.append((audio_path, output_path)) + Path(output_path).write_bytes(b"fake wav data") + return output_path + + monkeypatch.setattr( + "astrbot.core.provider.sources.whisper_api_source.get_astrbot_temp_path", + lambda: str(tmp_path), + ) + monkeypatch.setattr( + "astrbot.core.provider.sources.whisper_api_source.convert_audio_to_wav", + fake_convert_audio_to_wav, + ) + + try: + result = await provider.get_text(str(opus_path)) + + assert result == "transcribed text" + assert conversions and conversions[0][0] == str(opus_path) + converted_path = Path(conversions[0][1]) + assert converted_path.suffix == ".wav" + assert not converted_path.exists() + + create_mock = provider.client.audio.transcriptions.create + create_mock.assert_awaited_once() + file_arg = create_mock.await_args.kwargs["file"] + assert file_arg[0] == "audio.wav" + assert file_arg[1].name.endswith(".wav") + file_arg[1].close() + finally: + await provider.terminate()