diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 567a2e147..aec571f9f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -458,6 +458,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}} return {"guardContent": result} + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_AudioBlock.html + if "audio" in content: + audio = content["audio"] + source = audio["source"] + formatted_source = {} + if "bytes" in source: + formatted_source = {"bytes": source["bytes"]} + result = {"format": audio["format"], "source": formatted_source} + return {"audio": result} + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html if "image" in content: image = content["image"] diff --git a/src/strands/types/content.py b/src/strands/types/content.py index d75dbb87f..5b103334f 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -11,7 +11,7 @@ from typing_extensions import TypedDict from .citations import CitationsContentBlock -from .media import DocumentContent, ImageContent, VideoContent +from .media import AudioContent, DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -75,6 +75,7 @@ class ContentBlock(TypedDict, total=False): """A block of content for a message that you pass to, or receive from, a model. Attributes: + audio: Audio to include in the message. cachePoint: A cache point configuration to optimize conversation history. document: A document to include in the message. guardContent: Contains the content to assess with the guardrail. @@ -87,6 +88,7 @@ class ContentBlock(TypedDict, total=False): citationsContent: Contains the citations for a document. """ + audio: AudioContent cachePoint: CachePoint document: DocumentContent guardContent: GuardContent diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 462d8af34..54ee137aa 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -91,3 +91,29 @@ class VideoContent(TypedDict): format: VideoFormat source: VideoSource + + +AudioFormat = Literal["mp3", "wav", "flac", "ogg", "aac", "webm"] +"""Supported audio formats.""" + + +class AudioSource(TypedDict): + """Contains the content of audio data. + + Attributes: + bytes: The binary content of the audio. + """ + + bytes: bytes + + +class AudioContent(TypedDict): + """Audio to include in a message. + + Attributes: + format: The format of the audio (e.g., "mp3", "wav"). + source: The source containing the audio's binary content. + """ + + format: AudioFormat + source: AudioSource diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 7697c5e03..53220a697 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1888,6 +1888,33 @@ def test_format_request_filters_video_content_blocks(model, model_id): assert "resolution" not in video_block +def test_format_request_filters_audio_content_blocks(model, model_id): + """Test that format_request filters extra fields from audio content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "audio": { + "format": "mp3", + "source": {"bytes": b"audio_data"}, + "duration": 120, # Extra field that should be filtered + "bitrate": "320kbps", # Extra field that should be filtered + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + audio_block = formatted_request["messages"][0]["content"][0]["audio"] + expected = {"format": "mp3", "source": {"bytes": b"audio_data"}} + assert audio_block == expected + assert "duration" not in audio_block + assert "bitrate" not in audio_block + + def test_format_request_filters_cache_point_content_blocks(model, model_id): """Test that format_request filters extra fields from cachePoint content blocks.""" messages = [ diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py index 07ec8f2ba..44d9a626a 100644 --- a/tests/strands/tools/test_decorator_pep563.py +++ b/tests/strands/tools/test_decorator_pep563.py @@ -10,10 +10,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import pytest -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from strands import tool