# FEAT: Sora target: support remix, image-to-video #1341
```diff
@@ -2,13 +2,14 @@
 # Licensed under the MIT license.
 
 import logging
-from typing import Any
+import os
+from typing import Any, Optional
 
 from pyrit.exceptions import (
     pyrit_target_retry,
 )
 from pyrit.identifiers import TargetIdentifier
 from pyrit.models import (
     DataTypeSerializer,
     Message,
     MessagePiece,
     construct_response_from_request,
```
```diff
@@ -27,13 +28,20 @@ class OpenAIVideoTarget(OpenAITarget):
 
     Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API.
 
+    Supports three modes:
+    - Text-to-video: Generate video from a text prompt
+    - Image-to-video: Generate video using an image as the first frame (include image_path piece)
+    - Remix: Create variation of existing video (include video_id in prompt_metadata)
+
     Supported resolutions:
     - Sora-2: 720x1280, 1280x720
     - Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024
 
     Supported durations: 4, 8, or 12 seconds
 
     Default: resolution="1280x720", duration=4 seconds
+
+    Supported image formats for image-to-video: JPEG, PNG, WEBP
     """
 
     SUPPORTED_RESOLUTIONS = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
```
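The docstring above enumerates the three modes. A minimal sketch of how a caller might drive each one — the `MessagePiece` constructor arguments (`role`, `original_value`, `original_value_data_type`, `prompt_metadata`) and the `pyrit.prompt_target` import path are assumed from PyRIT conventions rather than taken from this diff:

```python
# Hedged sketch: exercising text-to-video, image-to-video, and remix modes.
# Constructor argument names are assumptions, not verified against this PR.
import asyncio

from pyrit.models import Message, MessagePiece
from pyrit.prompt_target import OpenAIVideoTarget


async def main() -> None:
    target = OpenAIVideoTarget(resolution_dimensions="1280x720", n_seconds=4)

    # Text-to-video: a single text piece.
    text_only = Message(message_pieces=[
        MessagePiece(role="user", original_value="A koala surfing at sunset"),
    ])

    # Image-to-video: add an image_path piece; the image becomes the first frame.
    image_to_video = Message(message_pieces=[
        MessagePiece(role="user", original_value="Animate this scene"),
        MessagePiece(role="user", original_value="frame.png", original_value_data_type="image_path"),
    ])

    # Remix: set prompt_metadata["video_id"] to a previously generated video's ID.
    remix = Message(message_pieces=[
        MessagePiece(
            role="user",
            original_value="Make it snow",
            prompt_metadata={"video_id": "video_abc123"},  # placeholder ID
        ),
    ])

    responses = await target.send_prompt_async(message=text_only)
    print(responses[0].message_pieces[0].converted_value)  # path to the saved video


if __name__ == "__main__":
    asyncio.run(main())
```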
```diff
@@ -96,20 +104,6 @@ def _get_provider_examples(self) -> dict[str, str]:
             "api.openai.com": "https://api.openai.com/v1",
         }
 
-    def _build_identifier(self) -> TargetIdentifier:
-        """
-        Build the identifier with video generation-specific parameters.
-
-        Returns:
-            TargetIdentifier: The identifier for this target instance.
-        """
-        return self._create_identifier(
-            target_specific_params={
-                "resolution": self._size,
-                "n_seconds": self._n_seconds,
-            },
-        )
-
     def _validate_resolution(self, *, resolution_dimensions: str) -> str:
         """
         Validate resolution dimensions.
```
```diff
@@ -149,6 +143,11 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
         """
         Asynchronously sends a message and generates a video using the OpenAI SDK.
 
+        Supports three modes:
+        - Text-to-video: Single text piece
+        - Image-to-video: Text piece + image_path piece (image becomes first frame)
+        - Remix: Text piece with prompt_metadata["video_id"] set to an existing video ID
+
         Args:
             message (Message): The message object containing the prompt.
```
```diff
@@ -160,23 +159,91 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
             ValueError: If the request is invalid.
         """
         self._validate_request(message=message)
-        message_piece = message.message_pieces[0]
-        prompt = message_piece.converted_value
+
+        # Extract pieces by type
+        pieces = message.message_pieces
+        text_piece = next(p for p in pieces if p.converted_value_data_type == "text")
+        image_piece = next((p for p in pieces if p.converted_value_data_type == "image_path"), None)
+        prompt = text_piece.converted_value
+
+        # Check for remix mode via prompt_metadata
+        remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None
 
         logger.info(f"Sending video generation prompt: {prompt}")
 
-        # Use unified error handler - automatically detects Video and validates
-        response = await self._handle_openai_request(
-            api_call=lambda: self._async_client.videos.create_and_poll(
-                model=self._model_name,
-                prompt=prompt,
-                size=self._size,  # type: ignore[arg-type]
-                seconds=str(self._n_seconds),  # type: ignore[arg-type]
-            ),
-            request=message,
-        )
+        if remix_video_id:
+            # REMIX MODE: Create variation of existing video
+            logger.info(f"Remix mode: Creating variation of video {remix_video_id}")
+            response = await self._handle_openai_request(
+                api_call=lambda: self._remix_and_poll_async(video_id=remix_video_id, prompt=prompt),
+                request=message,
+            )
+        elif image_piece:
+            # IMAGE-TO-VIDEO MODE: Use image as first frame
+            logger.info("Image-to-video mode: Using image as first frame")
+            image_path = image_piece.converted_value
+            image_serializer = data_serializer_factory(
+                value=image_path, data_type="image_path", category="prompt-memory-entries"
+            )
+            image_bytes = await image_serializer.read_data()
+
+            # Get MIME type for proper file upload (API requires content-type)
+            mime_type = DataTypeSerializer.get_mime_type(image_path)
+            if not mime_type:
+                # Default to PNG if MIME type cannot be determined
+                mime_type = "image/png"
+
+            # Create file tuple with filename and MIME type for OpenAI SDK
+            # Format: (filename, content, content_type)
+            filename = os.path.basename(image_path)
+            input_file = (filename, image_bytes, mime_type)
```
**Contributor** commented on lines +182 to +199:

This function is a bit too long, way more than 20 lines. Can we split this branch and move some of the logic out to specialized functions? e.g.

```python
async def _prepare_image_input_async(self, *, image_path: str) -> tuple[str, bytes, str]:
    image_serializer = data_serializer_factory(
        value=image_path, data_type="image_path", category="prompt-memory-entries"
    )
    image_bytes = await image_serializer.read_data()
    mime_type = DataTypeSerializer.get_mime_type(image_path) or "image/png"
    filename = os.path.basename(image_path)
    return (filename, image_bytes, mime_type)
```
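With a helper like that in place, the image branch's call site could shrink to a single line; a sketch assuming the `_prepare_image_input_async` signature proposed above:

```python
# Hypothetical call site after the suggested refactor.
input_file = await self._prepare_image_input_async(image_path=image_piece.converted_value)
```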
```diff
+
+            response = await self._handle_openai_request(
+                api_call=lambda: self._async_client.videos.create_and_poll(
+                    model=self._model_name,
+                    prompt=prompt,
+                    size=self._size,  # type: ignore[arg-type]
+                    seconds=str(self._n_seconds),  # type: ignore[arg-type]
```
**Contributor** commented on lines +205 to +206:

Please do not use `# type: ignore[arg-type]` here. The OpenAI SDK types these parameters as:

```python
VideoSize: Literal["720x1280", "1280x720", "1024x1792", "1792x1024"]
VideoSeconds: Literal["4", "8", "12"]  # these are strings
```

This code has 2 issues: `size` is passed as a plain `str`, and `seconds` is converted with `str(...)` at the call site, so both lines need a type ignore. I think you should follow the established patterns we have, e.g. openai_tts_target.py and openai_chat_audio_config.py:

```python
VideoSize = Literal["720x1280", "1280x720", "1024x1792", "1792x1024"]
VideoSeconds = Literal["4", "8", "12"]


class OpenAIVideoTarget(OpenAITarget):
    SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
    SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"]

    def __init__(
        self,
        *,
        resolution_dimensions: VideoSize = "1280x720",
        n_seconds: VideoSeconds = "4",
        **kwargs: Any,
    ) -> None:
        ...
        self._n_seconds: VideoSeconds = n_seconds
        self._size: VideoSize = resolution_dimensions
```

Then you don't need the `str(...)` conversion or the type ignores. If you want to keep accepting integers for `n_seconds`:

```python
def __init__(
    self,
    *,
    resolution_dimensions: VideoSize = "1280x720",
    n_seconds: int | VideoSeconds = 4,  # you can accept both here
    **kwargs: Any,
) -> None:
    ...
    self._n_seconds: VideoSeconds = str(n_seconds) if isinstance(n_seconds, int) else n_seconds
    self._validate_duration()
```
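For context on this suggestion: `Literal` parameter types let the type checker reject invalid values at the call site instead of at request time. A standalone illustration (not PyRIT code; the function name is for demonstration only):

```python
from typing import Literal

VideoSize = Literal["720x1280", "1280x720", "1024x1792", "1792x1024"]


def set_size(size: VideoSize) -> None:
    """Accepts only the four supported resolution strings."""
    print(f"size set to {size}")


set_size("1280x720")  # OK: a member of the Literal
# set_size("1920x1080")  # mypy would flag this: not a valid VideoSize
```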
```diff
+                    input_reference=input_file,
+                ),
+                request=message,
+            )
+        else:
+            # TEXT-TO-VIDEO MODE: Standard generation
+            response = await self._handle_openai_request(
+                api_call=lambda: self._async_client.videos.create_and_poll(
+                    model=self._model_name,
+                    prompt=prompt,
+                    size=self._size,  # type: ignore[arg-type]
+                    seconds=str(self._n_seconds),  # type: ignore[arg-type]
+                ),
+                request=message,
+            )
 
         return [response]
 
+    async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any:
+        """
+        Create a remix of an existing video and poll until complete.
+
+        The OpenAI SDK's remix() method returns immediately with a job status.
+        This method polls until the job completes or fails.
+
+        Args:
+            video_id: The ID of the completed video to remix.
+            prompt: The text prompt directing the remix.
+
+        Returns:
+            The completed Video object from the OpenAI SDK.
+        """
+        video = await self._async_client.videos.remix(video_id, prompt=prompt)
+
+        # Poll until completion if not already done
+        if video.status not in ["completed", "failed"]:
+            video = await self._async_client.videos.poll(video.id)
+
+        return video
+
     def _check_content_filter(self, response: Any) -> bool:
         """
         Check if a video generation response was content filtered.
```
```diff
@@ -218,13 +285,17 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
         if video.status == "completed":
             logger.info(f"Video generation completed successfully: {video.id}")
 
+            # Log remix metadata if available
+            if hasattr(video, "remixed_from_video_id") and video.remixed_from_video_id:
+                logger.info(f"Video was remixed from: {video.remixed_from_video_id}")
+
             # Download video content using SDK
             video_response = await self._async_client.videos.download_content(video.id)
             # Extract bytes from HttpxBinaryResponseContent
             video_content = video_response.content
 
-            # Save the video to storage
-            return await self._save_video_response(request=request, video_data=video_content)
+            # Save the video to storage (include video.id for chaining remixes)
+            return await self._save_video_response(request=request, video_data=video_content, video_id=video.id)
 
         elif video.status == "failed":
             # Handle failed video generation (non-content-filter)
```
```diff
@@ -249,13 +320,16 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
                 error="unknown",
             )
 
-    async def _save_video_response(self, *, request: MessagePiece, video_data: bytes) -> Message:
+    async def _save_video_response(
+        self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None
+    ) -> Message:
         """
         Save video data to storage and construct response.
 
         Args:
             request: The original request message piece.
             video_data: The video content as bytes.
+            video_id: The video ID from the API (stored in metadata for chaining remixes).
 
         Returns:
             Message: The response with the video file path.
```
```diff
@@ -267,11 +341,15 @@ async def _save_video_response(self, *, request: MessagePiece, video_data: bytes
 
         logger.info(f"Video saved to: {video_path}")
 
+        # Include video_id in metadata for chaining (e.g., remix the generated video later)
+        prompt_metadata = {"video_id": video_id} if video_id else None
+
         # Construct response
         response_entry = construct_response_from_request(
             request=request,
             response_text_pieces=[video_path],
             response_type="video_path",
+            prompt_metadata=prompt_metadata,
         )
 
         return response_entry
```
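Since the response piece now carries the generated video's ID in `prompt_metadata`, a later request can remix it. A sketch of that chaining flow, reusing `target` and `text_only` from the earlier example (assumed to run inside an async context):

```python
# Hedged sketch: chain a remix off a freshly generated video. Assumes the
# response piece exposes prompt_metadata["video_id"], as set in the diff above.
first = (await target.send_prompt_async(message=text_only))[0]
generated_id = first.message_pieces[0].prompt_metadata["video_id"]

remix_request = Message(message_pieces=[
    MessagePiece(
        role="user",
        original_value="Same scene, but at night",
        prompt_metadata={"video_id": generated_id},  # triggers remix mode
    ),
])
second = (await target.send_prompt_async(message=remix_request))[0]
```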
```diff
@@ -280,19 +358,45 @@ def _validate_request(self, *, message: Message) -> None:
         """
         Validate the request message.
 
+        Accepts:
+        - Single text piece (text-to-video or remix mode)
+        - Text piece + image_path piece (image-to-video mode)
+
         Args:
             message: The message to validate.
 
         Raises:
             ValueError: If the request is invalid.
         """
-        n_pieces = len(message.message_pieces)
-        if n_pieces != 1:
-            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")
-
-        piece_type = message.message_pieces[0].converted_value_data_type
-        if piece_type != "text":
-            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")
+        pieces = message.message_pieces
+        n_pieces = len(pieces)
+
+        if n_pieces == 0:
+            raise ValueError("Message must contain at least one piece.")
+
+        # Categorize pieces
+        text_pieces = [p for p in pieces if p.converted_value_data_type == "text"]
+        image_pieces = [p for p in pieces if p.converted_value_data_type == "image_path"]
+        other_pieces = [p for p in pieces if p.converted_value_data_type not in ("text", "image_path")]
```
**Contributor** commented on lines +378 to +380:

You can use the function I suggested above here:

```python
text_pieces = message.get_pieces_by_type(data_type="text")
image_pieces = message.get_pieces_by_type(data_type="image_path")
other_pieces = message.get_pieces_by_type(exclude_types={"text", "image_path"})
```
```diff
+
+        # Must have exactly one text piece
+        if len(text_pieces) != 1:
+            raise ValueError(f"Expected exactly 1 text piece, got {len(text_pieces)}.")
+
+        # At most one image piece
+        if len(image_pieces) > 1:
+            raise ValueError(f"Expected at most 1 image piece, got {len(image_pieces)}.")
+
+        # No other data types allowed
+        if other_pieces:
+            types = [p.converted_value_data_type for p in other_pieces]
+            raise ValueError(f"Unsupported piece types: {types}. Only 'text' and 'image_path' are supported.")
+
+        # Check for conflicting modes: remix + image
+        text_piece = text_pieces[0]
+        remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None
+        if remix_video_id and image_pieces:
+            raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.")
 
     def is_json_response_supported(self) -> bool:
         """
```
**Contributor** commented:

You can make this a function inside the `Message` class. Or, if you want it to be very efficient (and don't mind a bit of code duplication), you can use `next` for a `get_piece_by_type` variant. Then your validation code becomes the three `get_pieces_by_type` calls shown in the comment above.
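A plausible shape for the helpers that comment describes — a sketch only, assuming `converted_value_data_type` is the discriminator field; signatures follow the usage in the review comment on the validation hunk, not the reviewer's exact proposal:

```python
from typing import Optional


class Message:  # illustrative stub; the real class lives in pyrit.models
    def __init__(self, message_pieces: list) -> None:
        self.message_pieces = message_pieces

    def get_pieces_by_type(
        self,
        *,
        data_type: Optional[str] = None,
        exclude_types: Optional[set[str]] = None,
    ) -> list:
        """Return pieces matching data_type, or all pieces whose type is not excluded."""
        if data_type is not None:
            return [p for p in self.message_pieces if p.converted_value_data_type == data_type]
        if exclude_types is not None:
            return [p for p in self.message_pieces if p.converted_value_data_type not in exclude_types]
        return list(self.message_pieces)

    def get_piece_by_type(self, *, data_type: str):
        """The next()-based variant: return the first match (or None) without building a list."""
        return next(
            (p for p in self.message_pieces if p.converted_value_data_type == data_type),
            None,
        )
```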