diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8e1558ca7..871b0db47 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -414,9 +414,13 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "format" in document: result["format"] = document["format"] - # Handle source + # Handle source (supports both bytes and s3Location) if "source" in document: - result["source"] = {"bytes": document["source"]["bytes"]} + source = document["source"] + if "bytes" in source: + result["source"] = {"bytes": source["bytes"]} + elif "s3Location" in source: + result["source"] = {"s3Location": source["s3Location"]} # Handle optional fields if "citations" in document and document["citations"] is not None: @@ -437,11 +441,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "image" in content: image = content["image"] source = image["source"] - formatted_source = {} + image_source: dict[str, Any] = {} if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": image["format"], "source": formatted_source} - return {"image": result} + image_source = {"bytes": source["bytes"]} + elif "s3Location" in source: + image_source = {"s3Location": source["s3Location"]} + return {"image": {"format": image["format"], "source": image_source}} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html if "reasoningContent" in content: @@ -502,11 +507,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "video" in content: video = content["video"] source = video["source"] - formatted_source = {} + video_source: dict[str, Any] = {} if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": video["format"], "source": formatted_source} - return {"video": result} + video_source = {"bytes": source["bytes"]} + elif "s3Location" in source: + video_source = {"s3Location": source["s3Location"]} + return {"video": {"format": video["format"], "source": video_source}} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html if "citationsContent" in content: diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..d2ccb0bcb 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -11,18 +11,33 @@ from .citations import CitationsConfig + +class S3Location(TypedDict, total=False): + """S3 location for media content. + + Attributes: + uri: The S3 URI of the content. + bucketOwner: The account ID of the bucket owner. + """ + + uri: str + bucketOwner: str + + DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] """Supported document formats.""" -class DocumentSource(TypedDict): +class DocumentSource(TypedDict, total=False): """Contains the content of a document. Attributes: bytes: The binary content of the document. + s3Location: The S3 location of the document. """ bytes: bytes + s3Location: S3Location class DocumentContent(TypedDict, total=False): @@ -45,14 +60,16 @@ class DocumentContent(TypedDict, total=False): """Supported image formats.""" -class ImageSource(TypedDict): +class ImageSource(TypedDict, total=False): """Contains the content of an image. Attributes: bytes: The binary content of the image. + s3Location: The S3 location of the image. """ bytes: bytes + s3Location: S3Location class ImageContent(TypedDict): @@ -71,14 +88,16 @@ class ImageContent(TypedDict): """Supported video formats.""" -class VideoSource(TypedDict): +class VideoSource(TypedDict, total=False): """Contains the content of a video. Attributes: bytes: The binary content of the video. + s3Location: The S3 location of the video. """ bytes: bytes + s3Location: S3Location class VideoContent(TypedDict): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 7697c5e03..9de08c4ea 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2240,3 +2240,76 @@ async def test_format_request_with_guardrail_latest_message(model): # Latest user message image should also be wrapped assert "guardContent" in formatted_messages[2]["content"][1] assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" + + +def test_format_request_s3_location_document_source(model, model_id): + """Test that s3Location source is supported for documents when bytes is not present.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "test.pdf", + "format": "pdf", + "source": {"s3Location": {"uri": "s3://bucket/key.pdf"}}, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + document_block = formatted_request["messages"][0]["content"][0]["document"] + expected = {"name": "test.pdf", "format": "pdf", "source": {"s3Location": {"uri": "s3://bucket/key.pdf"}}} + assert document_block == expected + + +def test_format_request_s3_location_image_source(model, model_id): + """Test that s3Location source is supported for images when bytes is not present.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"s3Location": {"uri": "s3://bucket/image.png"}}, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + image_block = formatted_request["messages"][0]["content"][0]["image"] + expected = {"format": "png", "source": {"s3Location": {"uri": "s3://bucket/image.png"}}} + assert image_block == expected + + +def test_format_request_s3_location_video_source(model, model_id): + """Test that s3Location source is supported for videos when bytes is not present.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": {"s3Location": {"uri": "s3://bucket/video.mp4", "bucketOwner": "123456789012"}}, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + video_block = formatted_request["messages"][0]["content"][0]["video"] + expected = { + "format": "mp4", + "source": {"s3Location": {"uri": "s3://bucket/video.mp4", "bucketOwner": "123456789012"}}, + } + assert video_block == expected diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 9cb90167d..af56cf367 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -824,3 +824,56 @@ def short_names_only(tool) -> bool: # Should only include short tool (name length <= 10) assert len(result) == 1 assert result[0] is mock_agent_tool1 + + +def test_is_session_active_with_close_future_done(): + """Test that _is_session_active returns False when close_future is done.""" + from unittest.mock import Mock + + client = MCPClient(transport_callable=lambda: Mock()) + + # Mock background thread as alive + client._background_thread = Mock() + client._background_thread.is_alive.return_value = True + + # Mock close_future as done + client._close_future = Mock() + client._close_future.done.return_value = True + + # Should return False because close_future is done + assert client._is_session_active() is False + + +def test_is_session_active_with_close_future_not_done(): + """Test that _is_session_active returns True when close_future is not done.""" + from unittest.mock import Mock + + client = MCPClient(transport_callable=lambda: Mock()) + + # Mock background thread as alive + client._background_thread = Mock() + client._background_thread.is_alive.return_value = True + + # Mock close_future as not done + client._close_future = Mock() + client._close_future.done.return_value = False + + # Should return True + assert client._is_session_active() is True + + +def test_is_session_active_with_none_close_future(): + """Test that _is_session_active returns True when close_future is None.""" + from unittest.mock import Mock + + client = MCPClient(transport_callable=lambda: Mock()) + + # Mock background thread as alive + client._background_thread = Mock() + client._background_thread.is_alive.return_value = True + + # close_future is None + client._close_future = None + + # Should return True + assert client._is_session_active() is True