Skip to content

Commit bda3c8b

Browse files
AssemblyAIzkleb-aai
authored andcommitted
Project import generated by Copybara.
GitOrigin-RevId: 9cf018b501831b8b6343526095eeac6a88075fdf
1 parent 99fd42e commit bda3c8b

3 files changed

Lines changed: 95 additions & 1 deletion

File tree

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.53.0"
1+
__version__ = "0.54.0"

assemblyai/streaming/v3/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class TurnEvent(BaseModel):
3434
words: List[Word]
3535
language_code: Optional[str] = None
3636
language_confidence: Optional[float] = None
37+
speaker_label: Optional[str] = None
3738

3839

3940
class BeginEvent(BaseModel):
@@ -102,6 +103,7 @@ class SpeechModel(str, Enum):
102103
universal_streaming_multilingual = "universal-streaming-multilingual"
103104
universal_streaming_english = "universal-streaming-english"
104105
u3_rt_pro = "u3-rt-pro"
106+
whisper_rt = "whisper-rt"
105107
u3_pro = "u3-pro" # Deprecated: Use u3_rt_pro instead
106108

107109
def __str__(self):
@@ -118,6 +120,8 @@ class StreamingParameters(StreamingSessionParameters):
118120
webhook_auth_header_name: Optional[str] = None
119121
webhook_auth_header_value: Optional[str] = None
120122
llm_gateway: Optional[LLMGatewayConfig] = None
123+
speaker_labels: Optional[bool] = None
124+
max_speakers: Optional[int] = None
121125

122126

123127
class UpdateConfiguration(StreamingSessionParameters):

tests/unit/test_streaming.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
StreamingClient,
88
StreamingClientOptions,
99
StreamingParameters,
10+
TurnEvent,
1011
)
1112

1213

@@ -262,3 +263,92 @@ def mocked_websocket_connect(
262263
assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"]
263264

264265
assert actual_open_timeout == 15
266+
267+
268+
def test_client_connect_with_speaker_labels(mocker: MockFixture):
269+
actual_url = None
270+
271+
def mocked_websocket_connect(
272+
url: str, additional_headers: dict, open_timeout: float
273+
):
274+
nonlocal actual_url
275+
actual_url = url
276+
277+
mocker.patch(
278+
"assemblyai.streaming.v3.client.websocket_connect",
279+
new=mocked_websocket_connect,
280+
)
281+
282+
_disable_rw_threads(mocker)
283+
284+
options = StreamingClientOptions(api_key="test", api_host="api.example.com")
285+
client = StreamingClient(options)
286+
287+
params = StreamingParameters(
288+
sample_rate=16000,
289+
speaker_labels=True,
290+
max_speakers=3,
291+
)
292+
293+
client.connect(params)
294+
295+
assert "speaker_labels=True" in actual_url
296+
assert "max_speakers=3" in actual_url
297+
298+
299+
def test_client_connect_with_whisper_rt(mocker: MockFixture):
300+
actual_url = None
301+
302+
def mocked_websocket_connect(
303+
url: str, additional_headers: dict, open_timeout: float
304+
):
305+
nonlocal actual_url
306+
actual_url = url
307+
308+
mocker.patch(
309+
"assemblyai.streaming.v3.client.websocket_connect",
310+
new=mocked_websocket_connect,
311+
)
312+
313+
_disable_rw_threads(mocker)
314+
315+
options = StreamingClientOptions(api_key="test", api_host="api.example.com")
316+
client = StreamingClient(options)
317+
318+
params = StreamingParameters(
319+
sample_rate=16000,
320+
speech_model=SpeechModel.whisper_rt,
321+
)
322+
323+
client.connect(params)
324+
325+
assert "speech_model=whisper-rt" in actual_url
326+
327+
328+
def test_turn_event_with_speaker_label():
329+
data = {
330+
"type": "Turn",
331+
"turn_order": 1,
332+
"turn_is_formatted": True,
333+
"end_of_turn": False,
334+
"transcript": "Hello world",
335+
"end_of_turn_confidence": 0.85,
336+
"words": [],
337+
"speaker_label": "B",
338+
}
339+
event = TurnEvent.parse_obj(data)
340+
assert event.speaker_label == "B"
341+
342+
343+
def test_turn_event_without_speaker_label():
344+
data = {
345+
"type": "Turn",
346+
"turn_order": 1,
347+
"turn_is_formatted": True,
348+
"end_of_turn": False,
349+
"transcript": "Hello world",
350+
"end_of_turn_confidence": 0.85,
351+
"words": [],
352+
}
353+
event = TurnEvent.parse_obj(data)
354+
assert event.speaker_label is None

0 commit comments

Comments
 (0)