diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py
index f6915c027..e34bdec97 100644
--- a/pyrit/prompt_target/openai/openai_video_target.py
+++ b/pyrit/prompt_target/openai/openai_video_target.py
@@ -2,13 +2,14 @@
 # Licensed under the MIT license.
 
 import logging
-from typing import Any
+import os
+from typing import Any, Optional
 
 from pyrit.exceptions import (
     pyrit_target_retry,
 )
-from pyrit.identifiers import TargetIdentifier
 from pyrit.models import (
+    DataTypeSerializer,
     Message,
     MessagePiece,
     construct_response_from_request,
@@ -27,6 +28,11 @@ class OpenAIVideoTarget(OpenAITarget):
 
     Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API.
 
+    Supports three modes:
+    - Text-to-video: Generate a video from a text prompt
+    - Image-to-video: Generate a video using an image as the first frame (include an image_path piece)
+    - Remix: Create a variation of an existing video (include video_id in prompt_metadata)
+
     Supported resolutions:
     - Sora-2: 720x1280, 1280x720
     - Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024
@@ -34,6 +40,8 @@ class OpenAIVideoTarget(OpenAITarget):
     Supported durations: 4, 8, or 12 seconds
 
     Default: resolution="1280x720", duration=4 seconds
+
+    Supported image formats for image-to-video: JPEG, PNG, WEBP
     """
 
     SUPPORTED_RESOLUTIONS = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
@@ -96,20 +104,6 @@ def _get_provider_examples(self) -> dict[str, str]:
             "api.openai.com": "https://api.openai.com/v1",
         }
 
-    def _build_identifier(self) -> TargetIdentifier:
-        """
-        Build the identifier with video generation-specific parameters.
-
-        Returns:
-            TargetIdentifier: The identifier for this target instance.
-        """
-        return self._create_identifier(
-            target_specific_params={
-                "resolution": self._size,
-                "n_seconds": self._n_seconds,
-            },
-        )
-
     def _validate_resolution(self, *, resolution_dimensions: str) -> str:
         """
         Validate resolution dimensions.
@@ -149,6 +143,11 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
         """
         Asynchronously sends a message and generates a video using the OpenAI SDK.
 
+        Supports three modes:
+        - Text-to-video: Single text piece
+        - Image-to-video: Text piece + image_path piece (image becomes the first frame)
+        - Remix: Text piece with prompt_metadata["video_id"] set to an existing video ID
+
         Args:
             message (Message): The message object containing the prompt.
 
         Returns:
             list[Message]: A list containing the response message with the generated video path.
 
         Raises:
             ValueError: If the request is invalid.
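+
+        Example (illustrative sketch, not a runnable doctest; assumes a configured
+        target named `target` and that MessagePiece defaults converted_value to
+        original_value):
+
+            piece = MessagePiece(role="user", original_value="a cat surfing a wave")
+            responses = await target.send_prompt_async(message=Message([piece]))
+
+            # The response carries the generated video's ID, which can be fed
+            # back in to remix the result:
+            video_id = responses[0].message_pieces[0].prompt_metadata["video_id"]
+            remix_piece = MessagePiece(
+                role="user",
+                original_value="same scene, but at night",
+                prompt_metadata={"video_id": video_id},
+            )
+            await target.send_prompt_async(message=Message([remix_piece]))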
""" self._validate_request(message=message) - message_piece = message.message_pieces[0] - prompt = message_piece.converted_value + + # Extract pieces by type + pieces = message.message_pieces + text_piece = next(p for p in pieces if p.converted_value_data_type == "text") + image_piece = next((p for p in pieces if p.converted_value_data_type == "image_path"), None) + prompt = text_piece.converted_value + + # Check for remix mode via prompt_metadata + remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None logger.info(f"Sending video generation prompt: {prompt}") - # Use unified error handler - automatically detects Video and validates - response = await self._handle_openai_request( - api_call=lambda: self._async_client.videos.create_and_poll( - model=self._model_name, - prompt=prompt, - size=self._size, # type: ignore[arg-type] - seconds=str(self._n_seconds), # type: ignore[arg-type] - ), - request=message, - ) + if remix_video_id: + # REMIX MODE: Create variation of existing video + logger.info(f"Remix mode: Creating variation of video {remix_video_id}") + response = await self._handle_openai_request( + api_call=lambda: self._remix_and_poll_async(video_id=remix_video_id, prompt=prompt), + request=message, + ) + elif image_piece: + # IMAGE-TO-VIDEO MODE: Use image as first frame + logger.info("Image-to-video mode: Using image as first frame") + image_path = image_piece.converted_value + image_serializer = data_serializer_factory( + value=image_path, data_type="image_path", category="prompt-memory-entries" + ) + image_bytes = await image_serializer.read_data() + + # Get MIME type for proper file upload (API requires content-type) + mime_type = DataTypeSerializer.get_mime_type(image_path) + if not mime_type: + # Default to PNG if MIME type cannot be determined + mime_type = "image/png" + + # Create file tuple with filename and MIME type for OpenAI SDK + # Format: (filename, content, content_type) + filename = os.path.basename(image_path) + input_file = (filename, image_bytes, mime_type) + + response = await self._handle_openai_request( + api_call=lambda: self._async_client.videos.create_and_poll( + model=self._model_name, + prompt=prompt, + size=self._size, # type: ignore[arg-type] + seconds=str(self._n_seconds), # type: ignore[arg-type] + input_reference=input_file, + ), + request=message, + ) + else: + # TEXT-TO-VIDEO MODE: Standard generation + response = await self._handle_openai_request( + api_call=lambda: self._async_client.videos.create_and_poll( + model=self._model_name, + prompt=prompt, + size=self._size, # type: ignore[arg-type] + seconds=str(self._n_seconds), # type: ignore[arg-type] + ), + request=message, + ) + return [response] + async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: + """ + Create a remix of an existing video and poll until complete. + + The OpenAI SDK's remix() method returns immediately with a job status. + This method polls until the job completes or fails. + + Args: + video_id: The ID of the completed video to remix. + prompt: The text prompt directing the remix. + + Returns: + The completed Video object from the OpenAI SDK. + """ + video = await self._async_client.videos.remix(video_id, prompt=prompt) + + # Poll until completion if not already done + if video.status not in ["completed", "failed"]: + video = await self._async_client.videos.poll(video.id) + + return video + def _check_content_filter(self, response: Any) -> bool: """ Check if a video generation response was content filtered. 
@@ -218,13 +285,17 @@ async def _construct_message_from_response(self, response: Any, request: Any) ->
         if video.status == "completed":
             logger.info(f"Video generation completed successfully: {video.id}")
 
+            # Log remix metadata if available
+            if hasattr(video, "remixed_from_video_id") and video.remixed_from_video_id:
+                logger.info(f"Video was remixed from: {video.remixed_from_video_id}")
+
             # Download video content using SDK
             video_response = await self._async_client.videos.download_content(video.id)
             # Extract bytes from HttpxBinaryResponseContent
             video_content = video_response.content
 
-            # Save the video to storage
-            return await self._save_video_response(request=request, video_data=video_content)
+            # Save the video to storage (include video.id for chaining remixes)
+            return await self._save_video_response(request=request, video_data=video_content, video_id=video.id)
 
         elif video.status == "failed":
             # Handle failed video generation (non-content-filter)
@@ -249,13 +320,16 @@ async def _construct_message_from_response(self, response: Any, request: Any) ->
                 error="unknown",
             )
 
-    async def _save_video_response(self, *, request: MessagePiece, video_data: bytes) -> Message:
+    async def _save_video_response(
+        self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None
+    ) -> Message:
         """
         Save video data to storage and construct response.
 
         Args:
             request: The original request message piece.
             video_data: The video content as bytes.
+            video_id: The video ID from the API (stored in metadata for chaining remixes).
 
         Returns:
             Message: The response with the video file path.
@@ -267,11 +341,15 @@ async def _save_video_response(self, *, request: MessagePiece, video_data: bytes
 
         logger.info(f"Video saved to: {video_path}")
 
+        # Include video_id in metadata for chaining (e.g., remix the generated video later)
+        prompt_metadata = {"video_id": video_id} if video_id else None
+
         # Construct response
         response_entry = construct_response_from_request(
             request=request,
             response_text_pieces=[video_path],
             response_type="video_path",
+            prompt_metadata=prompt_metadata,
         )
 
         return response_entry
 
@@ -280,19 +358,45 @@ def _validate_request(self, *, message: Message) -> None:
         """
         Validate the request message.
 
+        Accepts:
+        - Single text piece (text-to-video or remix mode)
+        - Text piece + image_path piece (image-to-video mode)
+
         Args:
             message: The message to validate.
 
         Raises:
             ValueError: If the request is invalid.
         """
-        n_pieces = len(message.message_pieces)
-        if n_pieces != 1:
-            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")
-
-        piece_type = message.message_pieces[0].converted_value_data_type
-        if piece_type != "text":
-            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")
+        pieces = message.message_pieces
+        n_pieces = len(pieces)
+
+        if n_pieces == 0:
+            raise ValueError("Message must contain at least one piece.")
+
+        # Categorize pieces
+        text_pieces = [p for p in pieces if p.converted_value_data_type == "text"]
+        image_pieces = [p for p in pieces if p.converted_value_data_type == "image_path"]
+        other_pieces = [p for p in pieces if p.converted_value_data_type not in ("text", "image_path")]
+
+        # Must have exactly one text piece
+        if len(text_pieces) != 1:
+            raise ValueError(f"Expected exactly 1 text piece, got {len(text_pieces)}.")
+
+        # At most one image piece
+        if len(image_pieces) > 1:
+            raise ValueError(f"Expected at most 1 image piece, got {len(image_pieces)}.")
+
+        # No other data types allowed
+        if other_pieces:
+            types = [p.converted_value_data_type for p in other_pieces]
+            raise ValueError(f"Unsupported piece types: {types}. Only 'text' and 'image_path' are supported.")
+
+        # Check for conflicting modes: remix + image
+        text_piece = text_pieces[0]
+        remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None
+        if remix_video_id and image_pieces:
+            raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.")
 
     def is_json_response_supported(self) -> bool:
         """
diff --git a/tests/unit/target/test_video_target.py b/tests/unit/target/test_video_target.py
index dbf16e6bc..a17835f57 100644
--- a/tests/unit/target/test_video_target.py
+++ b/tests/unit/target/test_video_target.py
@@ -54,8 +54,9 @@ def test_video_initialization_invalid_duration(patch_central_database):
     )
 
-def test_video_validate_request_length(video_target: OpenAIVideoTarget):
-    with pytest.raises(ValueError, match="single message piece"):
+def test_video_validate_request_multiple_text_pieces(video_target: OpenAIVideoTarget):
+    """Test validation rejects multiple text pieces."""
+    with pytest.raises(ValueError, match="Expected exactly 1 text piece"):
         conversation_id = str(uuid.uuid4())
         msg1 = MessagePiece(
             role="user", original_value="test1", converted_value="test1", conversation_id=conversation_id
         )
@@ -66,8 +67,9 @@ def test_video_validate_request_length(video_target: OpenAIVideoTarget):
         video_target._validate_request(message=Message([msg1, msg2]))
 
-def test_video_validate_prompt_type(video_target: OpenAIVideoTarget):
-    with pytest.raises(ValueError, match="text prompt input"):
+def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget):
+    """Test validation rejects image-only input (must have text)."""
+    with pytest.raises(ValueError, match="Expected exactly 1 text piece"):
         msg = MessagePiece(
             role="user", original_value="test", converted_value="test", converted_value_data_type="image_path"
         )
@@ -348,3 +350,528 @@ def test_check_content_filter_no_error_object(video_target: OpenAIVideoTarget):
     mock_video.error = None
 
     assert video_target._check_content_filter(mock_video) is False
+
+
+# Tests for image-to-video and remix features
+
+
+class TestVideoTargetValidation:
+    """Tests for video target validation with new features."""
+
+    def test_validate_accepts_text_only(self, video_target: OpenAIVideoTarget):
+        """Test validation accepts single text piece (text-to-video mode)."""
+        msg = MessagePiece(role="user", original_value="test prompt", converted_value="test prompt")
+        # Should not raise
+        video_target._validate_request(message=Message([msg]))
+
+    def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget):
+        """Test validation accepts text + image (image-to-video mode)."""
(image-to-video mode).""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate this", + converted_value="animate this", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + # Should not raise + video_target._validate_request(message=Message([msg_text, msg_image])) + + def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget): + """Test validation rejects multiple image pieces.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate", + converted_value="animate", + conversation_id=conversation_id, + ) + msg_img1 = MessagePiece( + role="user", + original_value="/path/img1.png", + converted_value="/path/img1.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + msg_img2 = MessagePiece( + role="user", + original_value="/path/img2.png", + converted_value="/path/img2.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="at most 1 image piece"): + video_target._validate_request(message=Message([msg_text, msg_img1, msg_img2])) + + def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarget): + """Test validation rejects unsupported data types.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="test", + converted_value="test", + conversation_id=conversation_id, + ) + msg_audio = MessagePiece( + role="user", + original_value="/path/audio.wav", + converted_value="/path/audio.wav", + converted_value_data_type="audio_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="Unsupported piece types"): + video_target._validate_request(message=Message([msg_text, msg_audio])) + + def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget): + """Test validation rejects remix mode combined with image input.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="remix prompt", + converted_value="remix prompt", + prompt_metadata={"video_id": "existing_video_123"}, + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="Cannot use image input in remix mode"): + video_target._validate_request(message=Message([msg_text, msg_image])) + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetImageToVideo: + """Tests for image-to-video functionality.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + @pytest.mark.asyncio + async def test_image_to_video_calls_create_with_input_reference(self, video_target: OpenAIVideoTarget): + """Test that image-to-video mode passes input_reference to create_and_poll.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate this image", + converted_value="animate this image", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + 
original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + + mock_video = MagicMock() + mock_video.id = "video_img2vid" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/output.mp4" + mock_serializer.save_data = AsyncMock() + + mock_image_serializer = MagicMock() + mock_image_serializer.read_data = AsyncMock(return_value=b"image bytes") + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime, + ): + # First call returns image serializer, second call returns video serializer + mock_factory.side_effect = [mock_image_serializer, mock_serializer] + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_mime.return_value = "image/png" + + response = await video_target.send_prompt_async(message=Message([msg_text, msg_image])) + + # Verify create_and_poll was called with input_reference as tuple with MIME type + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + # input_reference should be (filename, bytes, content_type) tuple + input_ref = call_kwargs["input_reference"] + assert isinstance(input_ref, tuple) + assert input_ref[0] == "image.png" # filename + assert input_ref[1] == b"image bytes" # content + assert input_ref[2] == "image/png" # MIME type + assert call_kwargs["prompt"] == "animate this image" + + # Verify response + assert len(response) == 1 + assert response[0].message_pieces[0].converted_value_data_type == "video_path" + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetRemix: + """Tests for video remix functionality.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + @pytest.mark.asyncio + async def test_remix_calls_remix_and_poll(self, video_target: OpenAIVideoTarget): + """Test that remix mode calls remix() and poll().""" + msg = MessagePiece( + role="user", + original_value="make it more dramatic", + converted_value="make it more dramatic", + prompt_metadata={"video_id": "existing_video_123"}, + conversation_id=str(uuid.uuid4()), + ) + + mock_remix_video = MagicMock() + mock_remix_video.id = "remixed_video_456" + mock_remix_video.status = "in_progress" + + mock_polled_video = MagicMock() + mock_polled_video.id = "remixed_video_456" + mock_polled_video.status = "completed" + mock_polled_video.error = None + mock_polled_video.remixed_from_video_id = "existing_video_123" + + mock_video_response = MagicMock() + mock_video_response.content = b"remixed video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/remixed.mp4" + mock_serializer.save_data = AsyncMock() + + with ( + patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix, + patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as 
+            patch.object(
+                video_target._async_client.videos, "download_content", new_callable=AsyncMock
+            ) as mock_download,
+            patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory,
+        ):
+            mock_remix.return_value = mock_remix_video
+            mock_poll.return_value = mock_polled_video
+            mock_download.return_value = mock_video_response
+            mock_factory.return_value = mock_serializer
+
+            response = await video_target.send_prompt_async(message=Message([msg]))
+
+            # Verify remix was called with correct params
+            mock_remix.assert_called_once_with("existing_video_123", prompt="make it more dramatic")
+            # Verify poll was called (since status was in_progress)
+            mock_poll.assert_called_once_with("remixed_video_456")
+
+            # Verify response
+            assert len(response) == 1
+            assert response[0].message_pieces[0].converted_value_data_type == "video_path"
+
+    @pytest.mark.asyncio
+    async def test_remix_skips_poll_if_completed(self, video_target: OpenAIVideoTarget):
+        """Test that remix mode skips poll() if already completed."""
+        msg = MessagePiece(
+            role="user",
+            original_value="remix prompt",
+            converted_value="remix prompt",
+            prompt_metadata={"video_id": "existing_video_123"},
+            conversation_id=str(uuid.uuid4()),
+        )
+
+        mock_video = MagicMock()
+        mock_video.id = "remixed_video"
+        mock_video.status = "completed"
+        mock_video.error = None
+        mock_video.remixed_from_video_id = "existing_video_123"
+
+        mock_video_response = MagicMock()
+        mock_video_response.content = b"remixed video data"
+
+        mock_serializer = MagicMock()
+        mock_serializer.value = "/path/to/remixed.mp4"
+        mock_serializer.save_data = AsyncMock()
+
+        with (
+            patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix,
+            patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll,
+            patch.object(
+                video_target._async_client.videos, "download_content", new_callable=AsyncMock
+            ) as mock_download,
+            patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory,
+        ):
+            mock_remix.return_value = mock_video
+            mock_download.return_value = mock_video_response
+            mock_factory.return_value = mock_serializer
+
+            await video_target.send_prompt_async(message=Message([msg]))
+
+            # Verify poll was NOT called since status was already completed
+            mock_poll.assert_not_called()
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestVideoTargetMetadata:
+    """Tests for video_id metadata storage in responses."""
+
+    @pytest.fixture
+    def video_target(self) -> OpenAIVideoTarget:
+        return OpenAIVideoTarget(
+            endpoint="https://api.openai.com/v1",
+            api_key="test",
+            model_name="sora-2",
+        )
+
+    @pytest.mark.asyncio
+    async def test_response_includes_video_id_metadata(self, video_target: OpenAIVideoTarget):
+        """Test that response includes video_id in prompt_metadata for chaining."""
+        msg = MessagePiece(
+            role="user",
+            original_value="test prompt",
+            converted_value="test prompt",
+            conversation_id=str(uuid.uuid4()),
+        )
+
+        mock_video = MagicMock()
+        mock_video.id = "new_video_789"
+        mock_video.status = "completed"
+        mock_video.error = None
+        mock_video.remixed_from_video_id = None
+
+        mock_video_response = MagicMock()
+        mock_video_response.content = b"video data"
+
+        mock_serializer = MagicMock()
+        mock_serializer.value = "/path/to/video.mp4"
+        mock_serializer.save_data = AsyncMock()
+
+        with (
+            patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create,
+            patch.object(
+                video_target._async_client.videos, "download_content", new_callable=AsyncMock
+            ) as mock_download,
+            patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory,
+        ):
+            mock_create.return_value = mock_video
+            mock_download.return_value = mock_video_response
+            mock_factory.return_value = mock_serializer
+
+            response = await video_target.send_prompt_async(message=Message([msg]))
+
+            # Verify response contains video_id in metadata for chaining
+            response_piece = response[0].message_pieces[0]
+            assert response_piece.prompt_metadata is not None
+            assert response_piece.prompt_metadata.get("video_id") == "new_video_789"
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestVideoTargetEdgeCases:
+    """Tests for edge cases and error scenarios."""
+
+    @pytest.fixture
+    def video_target(self) -> OpenAIVideoTarget:
+        return OpenAIVideoTarget(
+            endpoint="https://api.openai.com/v1",
+            api_key="test",
+            model_name="sora-2",
+        )
+
+    def test_validate_rejects_empty_message(self, video_target: OpenAIVideoTarget):
+        """Test that empty messages are rejected (by Message constructor)."""
+        with pytest.raises(ValueError, match="at least one message piece"):
+            Message([])
+
+    def test_validate_rejects_no_text_piece(self, video_target: OpenAIVideoTarget):
+        """Test validation rejects message without text piece."""
+        msg = MessagePiece(
+            role="user",
+            original_value="/path/image.png",
+            converted_value="/path/image.png",
+            converted_value_data_type="image_path",
+        )
+        with pytest.raises(ValueError, match="Expected exactly 1 text piece"):
+            video_target._validate_request(message=Message([msg]))
+
+    @pytest.mark.asyncio
+    async def test_image_to_video_with_jpeg(self, video_target: OpenAIVideoTarget):
+        """Test image-to-video with JPEG image format."""
+        conversation_id = str(uuid.uuid4())
+        msg_text = MessagePiece(
+            role="user",
+            original_value="animate",
+            converted_value="animate",
+            conversation_id=conversation_id,
+        )
+        msg_image = MessagePiece(
+            role="user",
+            original_value="/path/image.jpg",
+            converted_value="/path/image.jpg",
+            converted_value_data_type="image_path",
+            conversation_id=conversation_id,
+        )
+
+        mock_video = MagicMock()
+        mock_video.id = "video_jpeg"
+        mock_video.status = "completed"
+        mock_video.error = None
+        mock_video.remixed_from_video_id = None
+
+        mock_video_response = MagicMock()
+        mock_video_response.content = b"video data"
+
+        mock_serializer = MagicMock()
+        mock_serializer.value = "/path/to/output.mp4"
+        mock_serializer.save_data = AsyncMock()
+
+        mock_image_serializer = MagicMock()
+        mock_image_serializer.read_data = AsyncMock(return_value=b"jpeg bytes")
+
+        with (
+            patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create,
+            patch.object(
+                video_target._async_client.videos, "download_content", new_callable=AsyncMock
+            ) as mock_download,
+            patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory,
+            patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime,
+        ):
+            mock_factory.side_effect = [mock_image_serializer, mock_serializer]
+            mock_create.return_value = mock_video
+            mock_download.return_value = mock_video_response
+            mock_mime.return_value = "image/jpeg"
+
+            await video_target.send_prompt_async(message=Message([msg_text, msg_image]))
+
+            # Verify JPEG MIME type is used
+            call_kwargs = mock_create.call_args.kwargs
+            input_ref = call_kwargs["input_reference"]
+            assert input_ref[2] == "image/jpeg"
input_ref[2] == "image/jpeg" + + @pytest.mark.asyncio + async def test_image_to_video_with_unknown_mime_defaults_to_png(self, video_target: OpenAIVideoTarget): + """Test image-to-video defaults to PNG when MIME type cannot be determined.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate", + converted_value="animate", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.unknown", + converted_value="/path/image.unknown", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + + mock_video = MagicMock() + mock_video.id = "video_unknown" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/output.mp4" + mock_serializer.save_data = AsyncMock() + + mock_image_serializer = MagicMock() + mock_image_serializer.read_data = AsyncMock(return_value=b"unknown bytes") + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime, + ): + mock_factory.side_effect = [mock_image_serializer, mock_serializer] + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_mime.return_value = None # MIME type cannot be determined + + response = await video_target.send_prompt_async(message=Message([msg_text, msg_image])) + + # Verify default PNG MIME type is used + call_kwargs = mock_create.call_args.kwargs + input_ref = call_kwargs["input_reference"] + assert input_ref[2] == "image/png" # Default + + @pytest.mark.asyncio + async def test_remix_with_failed_status(self, video_target: OpenAIVideoTarget): + """Test remix mode handles failed video generation.""" + msg = MessagePiece( + role="user", + original_value="remix this", + converted_value="remix this", + prompt_metadata={"video_id": "existing_video"}, + conversation_id=str(uuid.uuid4()), + ) + + mock_video = MagicMock() + mock_video.id = "failed_remix" + mock_video.status = "failed" + mock_error = MagicMock() + mock_error.code = "internal_error" + mock_video.error = mock_error + + with ( + patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix, + patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll, + ): + mock_remix.return_value = mock_video + # Don't need poll since status is already "failed" + + response = await video_target.send_prompt_async(message=Message([msg])) + + # Verify response is processing error + response_piece = response[0].message_pieces[0] + assert response_piece.response_error == "processing" + + def test_supported_resolutions(self, video_target: OpenAIVideoTarget): + """Test that all supported resolutions are valid.""" + for resolution in OpenAIVideoTarget.SUPPORTED_RESOLUTIONS: + target = OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + resolution_dimensions=resolution, + ) + assert target._size == resolution + + def test_supported_durations(self, 
+        """Test that all supported durations are valid."""
+        for duration in OpenAIVideoTarget.SUPPORTED_DURATIONS:
+            target = OpenAIVideoTarget(
+                endpoint="https://api.openai.com/v1",
+                api_key="test",
+                model_name="sora-2",
+                n_seconds=duration,
+            )
+            assert target._n_seconds == duration
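+
+    @pytest.mark.asyncio
+    async def test_generation_chains_into_remix(self, video_target: OpenAIVideoTarget):
+        """Illustrative sketch of chaining: feed a response's video_id back in as remix input.
+
+        This is a usage example layered on the same mocking conventions as the
+        tests above rather than new coverage; the IDs and prompts are invented.
+        """
+        msg = MessagePiece(
+            role="user",
+            original_value="a lighthouse at dusk",
+            converted_value="a lighthouse at dusk",
+            conversation_id=str(uuid.uuid4()),
+        )
+
+        mock_video = MagicMock()
+        mock_video.id = "video_chain_1"
+        mock_video.status = "completed"
+        mock_video.error = None
+        mock_video.remixed_from_video_id = None
+
+        mock_remixed = MagicMock()
+        mock_remixed.id = "video_chain_2"
+        mock_remixed.status = "completed"
+        mock_remixed.error = None
+        mock_remixed.remixed_from_video_id = "video_chain_1"
+
+        mock_video_response = MagicMock()
+        mock_video_response.content = b"video data"
+
+        mock_serializer = MagicMock()
+        mock_serializer.value = "/path/to/video.mp4"
+        mock_serializer.save_data = AsyncMock()
+
+        with (
+            patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create,
+            patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix,
+            patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll,
+            patch.object(
+                video_target._async_client.videos, "download_content", new_callable=AsyncMock
+            ) as mock_download,
+            patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory,
+        ):
+            mock_create.return_value = mock_video
+            mock_remix.return_value = mock_remixed
+            mock_download.return_value = mock_video_response
+            mock_factory.return_value = mock_serializer
+
+            first = await video_target.send_prompt_async(message=Message([msg]))
+            video_id = first[0].message_pieces[0].prompt_metadata["video_id"]
+
+            remix_msg = MessagePiece(
+                role="user",
+                original_value="same lighthouse, in a storm",
+                converted_value="same lighthouse, in a storm",
+                prompt_metadata={"video_id": video_id},
+                conversation_id=str(uuid.uuid4()),
+            )
+            second = await video_target.send_prompt_async(message=Message([remix_msg]))
+
+            # The remix call receives the first video's ID, and the new response
+            # carries the remixed video's ID for any further chaining
+            mock_remix.assert_called_once_with("video_chain_1", prompt="same lighthouse, in a storm")
+            mock_poll.assert_not_called()
+            assert second[0].message_pieces[0].prompt_metadata.get("video_id") == "video_chain_2"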