From bda3c8ba5b02171822f565eb538d6812128f42ca Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Tue, 3 Mar 2026 11:28:09 -0500 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 9cf018b501831b8b6343526095eeac6a88075fdf --- assemblyai/__version__.py | 2 +- assemblyai/streaming/v3/models.py | 4 ++ tests/unit/test_streaming.py | 90 +++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 3c5da5f..450ee12 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.53.0" +__version__ = "0.54.0" diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 9132114..773ff0e 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -34,6 +34,7 @@ class TurnEvent(BaseModel): words: List[Word] language_code: Optional[str] = None language_confidence: Optional[float] = None + speaker_label: Optional[str] = None class BeginEvent(BaseModel): @@ -102,6 +103,7 @@ class SpeechModel(str, Enum): universal_streaming_multilingual = "universal-streaming-multilingual" universal_streaming_english = "universal-streaming-english" u3_rt_pro = "u3-rt-pro" + whisper_rt = "whisper-rt" u3_pro = "u3-pro" # Deprecated: Use u3_rt_pro instead def __str__(self): @@ -118,6 +120,8 @@ class StreamingParameters(StreamingSessionParameters): webhook_auth_header_name: Optional[str] = None webhook_auth_header_value: Optional[str] = None llm_gateway: Optional[LLMGatewayConfig] = None + speaker_labels: Optional[bool] = None + max_speakers: Optional[int] = None class UpdateConfiguration(StreamingSessionParameters): diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 84027ce..63d2076 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -7,6 +7,7 @@ StreamingClient, StreamingClientOptions, StreamingParameters, + TurnEvent, ) @@ -262,3 +263,92 @@ def mocked_websocket_connect( assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] assert actual_open_timeout == 15 + + +def test_client_connect_with_speaker_labels(mocker: MockFixture): + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + + _disable_rw_threads(mocker) + + options = StreamingClientOptions(api_key="test", api_host="api.example.com") + client = StreamingClient(options) + + params = StreamingParameters( + sample_rate=16000, + speaker_labels=True, + max_speakers=3, + ) + + client.connect(params) + + assert "speaker_labels=True" in actual_url + assert "max_speakers=3" in actual_url + + +def test_client_connect_with_whisper_rt(mocker: MockFixture): + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + + _disable_rw_threads(mocker) + + options = StreamingClientOptions(api_key="test", api_host="api.example.com") + client = StreamingClient(options) + + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.whisper_rt, + ) + + client.connect(params) + + assert "speech_model=whisper-rt" in actual_url + + +def test_turn_event_with_speaker_label(): + data = { + "type": "Turn", + "turn_order": 1, + "turn_is_formatted": True, + "end_of_turn": False, + "transcript": "Hello world", + "end_of_turn_confidence": 0.85, + "words": [], + "speaker_label": "B", + } + event = TurnEvent.parse_obj(data) + assert event.speaker_label == "B" + + +def test_turn_event_without_speaker_label(): + data = { + "type": "Turn", + "turn_order": 1, + "turn_is_formatted": True, + "end_of_turn": False, + "transcript": "Hello world", + "end_of_turn_confidence": 0.85, + "words": [], + } + event = TurnEvent.parse_obj(data) + assert event.speaker_label is None