180 changes: 142 additions & 38 deletions pyrit/prompt_target/openai/openai_video_target.py
@@ -2,13 +2,14 @@
# Licensed under the MIT license.

import logging
from typing import Any
import os
from typing import Any, Optional

from pyrit.exceptions import (
pyrit_target_retry,
)
from pyrit.identifiers import TargetIdentifier
from pyrit.models import (
DataTypeSerializer,
Message,
MessagePiece,
construct_response_from_request,
@@ -27,13 +28,20 @@ class OpenAIVideoTarget(OpenAITarget):

Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API.

Supports three modes:
- Text-to-video: Generate video from a text prompt
- Image-to-video: Generate video using an image as the first frame (include image_path piece)
- Remix: Create variation of existing video (include video_id in prompt_metadata)

Supported resolutions:
- Sora-2: 720x1280, 1280x720
- Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024

Supported durations: 4, 8, or 12 seconds

Default: resolution="1280x720", duration=4 seconds

Supported image formats for image-to-video: JPEG, PNG, WEBP
"""

SUPPORTED_RESOLUTIONS = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
@@ -96,20 +104,6 @@ def _get_provider_examples(self) -> dict[str, str]:
"api.openai.com": "https://api.openai.com/v1",
}

def _build_identifier(self) -> TargetIdentifier:
"""
Build the identifier with video generation-specific parameters.

Returns:
TargetIdentifier: The identifier for this target instance.
"""
return self._create_identifier(
target_specific_params={
"resolution": self._size,
"n_seconds": self._n_seconds,
},
)

def _validate_resolution(self, *, resolution_dimensions: str) -> str:
"""
Validate resolution dimensions.
@@ -149,6 +143,11 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
"""
Asynchronously sends a message and generates a video using the OpenAI SDK.

Supports three modes:
- Text-to-video: Single text piece
- Image-to-video: Text piece + image_path piece (image becomes first frame)
- Remix: Text piece with prompt_metadata["video_id"] set to an existing video ID

Args:
message (Message): The message object containing the prompt.

@@ -160,23 +159,91 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
ValueError: If the request is invalid.
"""
self._validate_request(message=message)
message_piece = message.message_pieces[0]
prompt = message_piece.converted_value

# Extract pieces by type
pieces = message.message_pieces
text_piece = next(p for p in pieces if p.converted_value_data_type == "text")
image_piece = next((p for p in pieces if p.converted_value_data_type == "image_path"), None)
Comment on lines +165 to +166 (Contributor):

You can make this a function inside the Message class, something like this:

def get_pieces_by_type(
    self,
    *,
    data_type: Optional[str] = None,
    exclude_types: Optional[set[str]] = None
) -> list[MessagePiece]:
    if data_type is not None and exclude_types is not None:
        raise ValueError("Cannot specify both data_type and exclude_types.")

    if data_type is not None:
        return [p for p in self.message_pieces if p.converted_value_data_type == data_type]

    if exclude_types is not None:
        return [p for p in self.message_pieces if p.converted_value_data_type not in exclude_types]

    return list(self.message_pieces)

def get_piece_by_type(
    self,
    *,
    data_type: str,
    required: bool = True
) -> Optional[MessagePiece]:
    pieces = self.get_pieces_by_type(data_type=data_type)
    
    if not pieces:
        if required:
            raise ValueError(f"No message piece with data type '{data_type}' found.")
        return None
    
    return pieces[0]

Or, if you want it to be more efficient (and don't mind a bit of code duplication), you can use next directly in get_piece_by_type:

def get_piece_by_type(
    self,
    *,
    data_type: str,
    required: bool = True
) -> Optional[MessagePiece]:
    piece = next(
        (p for p in self.message_pieces if p.converted_value_data_type == data_type),
        None
    )
    if piece is None and required:
        raise ValueError(f"No message piece with data type '{data_type}' found.")
    return piece

def get_pieces_by_type(
    self,
    *,
    data_type: Optional[str] = None,
    exclude_types: Optional[set[str]] = None
) -> list[MessagePiece]:
    if data_type is not None and exclude_types is not None:
        raise ValueError("Cannot specify both data_type and exclude_types.")

    if data_type is not None:
        return [p for p in self.message_pieces if p.converted_value_data_type == data_type]
    
    if exclude_types is not None:
        return [p for p in self.message_pieces if p.converted_value_data_type not in exclude_types]
    
    return list(self.message_pieces)

Then your code will become:

text_piece = message.get_piece_by_type(data_type="text")
image_piece = message.get_piece_by_type(data_type="image_path", required=False)
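
For instance, the required flag would behave like this (a hypothetical illustration, assuming the get_piece_by_type implementation sketched above):

# Raises ValueError: no 'audio_path' piece exists in a text-only message.
message.get_piece_by_type(data_type="audio_path")

# Returns None instead of raising.
maybe_piece = message.get_piece_by_type(data_type="audio_path", required=False)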

prompt = text_piece.converted_value

# Check for remix mode via prompt_metadata
remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None

logger.info(f"Sending video generation prompt: {prompt}")

# Use unified error handler - automatically detects Video and validates
response = await self._handle_openai_request(
api_call=lambda: self._async_client.videos.create_and_poll(
model=self._model_name,
prompt=prompt,
size=self._size, # type: ignore[arg-type]
seconds=str(self._n_seconds), # type: ignore[arg-type]
),
request=message,
)
if remix_video_id:
# REMIX MODE: Create variation of existing video
logger.info(f"Remix mode: Creating variation of video {remix_video_id}")
response = await self._handle_openai_request(
api_call=lambda: self._remix_and_poll_async(video_id=remix_video_id, prompt=prompt),
request=message,
)
elif image_piece:
# IMAGE-TO-VIDEO MODE: Use image as first frame
logger.info("Image-to-video mode: Using image as first frame")
image_path = image_piece.converted_value
image_serializer = data_serializer_factory(
value=image_path, data_type="image_path", category="prompt-memory-entries"
)
image_bytes = await image_serializer.read_data()

# Get MIME type for proper file upload (API requires content-type)
mime_type = DataTypeSerializer.get_mime_type(image_path)
if not mime_type:
# Default to PNG if MIME type cannot be determined
mime_type = "image/png"

# Create file tuple with filename and MIME type for OpenAI SDK
# Format: (filename, content, content_type)
filename = os.path.basename(image_path)
input_file = (filename, image_bytes, mime_type)
Comment on lines +182 to +199 (Contributor):

This function is a bit too long, way more than 20 lines. Can we split this branch and move some of the logic out to specialized functions? e.g.

async def _prepare_image_input_async(self, *, image_path: str) -> tuple[str, bytes, str]:
    image_serializer = data_serializer_factory(
        value=image_path, data_type="image_path", category="prompt-memory-entries"
    )
    image_bytes = await image_serializer.read_data()
    
    mime_type = DataTypeSerializer.get_mime_type(image_path) or "image/png"
    filename = os.path.basename(image_path)
    
    return (filename, image_bytes, mime_type)
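
With that helper in place, the image-to-video branch could collapse to something like this (a sketch assuming the surrounding _handle_openai_request call keeps its current shape):

elif image_piece:
    # IMAGE-TO-VIDEO MODE: Use image as first frame
    logger.info("Image-to-video mode: Using image as first frame")
    input_file = await self._prepare_image_input_async(image_path=image_piece.converted_value)
    response = await self._handle_openai_request(
        api_call=lambda: self._async_client.videos.create_and_poll(
            model=self._model_name,
            prompt=prompt,
            size=self._size,
            seconds=str(self._n_seconds),
            input_reference=input_file,
        ),
        request=message,
    )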


response = await self._handle_openai_request(
api_call=lambda: self._async_client.videos.create_and_poll(
model=self._model_name,
prompt=prompt,
size=self._size, # type: ignore[arg-type]
seconds=str(self._n_seconds), # type: ignore[arg-type]
Comment on lines +205 to +206 (Contributor @bashirpartovi, Jan 30, 2026):

Please do not use # type: ignore. It usually masks a code smell. Looking at the OpenAI SDK, it defines:

VideoSize: Literal["720x1280", "1280x720", "1024x1792", "1792x1024"]
VideoSeconds: Literal["4", "8", "12"] # these are strings

This code has 2 issues:

  1. self._size is typed as str, which is too broad
  2. self._n_seconds is an int, but the SDK expects Literal["4", "8", "12"] (string literals)

I think you should follow the established patterns we have, e.g. openai_tts_target.py and openai_chat_audio_config.py.

VideoSize = Literal["720x1280", "1280x720", "1024x1792", "1792x1024"]
VideoSeconds = Literal["4", "8", "12"]


class OpenAIVideoTarget(OpenAITarget):
    
    SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
    SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"]

    def __init__(
        self,
        *,
        resolution_dimensions: VideoSize = "1280x720",
        n_seconds: VideoSeconds = "4",  
        **kwargs: Any,
    ) -> None:
        ...
        self._n_seconds: VideoSeconds = n_seconds
        self._size: VideoSize = resolution_dimensions

Then you don't need # type: ignore for your SDK call. I understand this is a breaking API change, though; if we want a non-breaking change, you'd need to do something like this:

def __init__(
    self,
    *,
    resolution_dimensions: VideoSize = "1280x720",
    n_seconds: int | VideoSeconds = 4,  # you can accept both here
    **kwargs: Any,
) -> None:
    ...
    self._n_seconds: VideoSeconds = str(n_seconds) if isinstance(n_seconds, int) else n_seconds
    self._validate_duration() 
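
The _validate_duration helper referenced above is not defined in the comment; a minimal sketch, assuming the SUPPORTED_DURATIONS list and typing shown above, might be:

def _validate_duration(self) -> None:
    # Reject values outside the SDK's accepted literal set ("4", "8", "12").
    if self._n_seconds not in self.SUPPORTED_DURATIONS:
        raise ValueError(
            f"Invalid duration: {self._n_seconds}. Supported: {self.SUPPORTED_DURATIONS}"
        )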

input_reference=input_file,
),
request=message,
)
else:
# TEXT-TO-VIDEO MODE: Standard generation
response = await self._handle_openai_request(
api_call=lambda: self._async_client.videos.create_and_poll(
model=self._model_name,
prompt=prompt,
size=self._size, # type: ignore[arg-type]
seconds=str(self._n_seconds), # type: ignore[arg-type]
),
request=message,
)

return [response]

async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any:
"""
Create a remix of an existing video and poll until complete.

The OpenAI SDK's remix() method returns immediately with a job status.
This method polls until the job completes or fails.

Args:
video_id: The ID of the completed video to remix.
prompt: The text prompt directing the remix.

Returns:
The completed Video object from the OpenAI SDK.
"""
video = await self._async_client.videos.remix(video_id, prompt=prompt)

# Poll until completion if not already done
if video.status not in ["completed", "failed"]:
video = await self._async_client.videos.poll(video.id)

return video

def _check_content_filter(self, response: Any) -> bool:
"""
Check if a video generation response was content filtered.
@@ -218,13 +285,17 @@ async def _construct_message_from_response(self, response: Any, request: Any) ->
if video.status == "completed":
logger.info(f"Video generation completed successfully: {video.id}")

# Log remix metadata if available
if hasattr(video, "remixed_from_video_id") and video.remixed_from_video_id:
logger.info(f"Video was remixed from: {video.remixed_from_video_id}")

# Download video content using SDK
video_response = await self._async_client.videos.download_content(video.id)
# Extract bytes from HttpxBinaryResponseContent
video_content = video_response.content

# Save the video to storage
return await self._save_video_response(request=request, video_data=video_content)
# Save the video to storage (include video.id for chaining remixes)
return await self._save_video_response(request=request, video_data=video_content, video_id=video.id)

elif video.status == "failed":
# Handle failed video generation (non-content-filter)
@@ -249,13 +320,16 @@ async def _construct_message_from_response(self, response: Any, request: Any) ->
error="unknown",
)

async def _save_video_response(self, *, request: MessagePiece, video_data: bytes) -> Message:
async def _save_video_response(
self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None
) -> Message:
"""
Save video data to storage and construct response.

Args:
request: The original request message piece.
video_data: The video content as bytes.
video_id: The video ID from the API (stored in metadata for chaining remixes).

Returns:
Message: The response with the video file path.
@@ -267,11 +341,15 @@ async def _save_video_response(self, *, request: MessagePiece, video_data: bytes

logger.info(f"Video saved to: {video_path}")

# Include video_id in metadata for chaining (e.g., remix the generated video later)
prompt_metadata = {"video_id": video_id} if video_id else None

# Construct response
response_entry = construct_response_from_request(
request=request,
response_text_pieces=[video_path],
response_type="video_path",
prompt_metadata=prompt_metadata,
)

return response_entry
@@ -280,19 +358,45 @@ def _validate_request(self, *, message: Message) -> None:
"""
Validate the request message.

Accepts:
- Single text piece (text-to-video or remix mode)
- Text piece + image_path piece (image-to-video mode)

Args:
message: The message to validate.

Raises:
ValueError: If the request is invalid.
"""
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type != "text":
raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")
pieces = message.message_pieces
n_pieces = len(pieces)

if n_pieces == 0:
raise ValueError("Message must contain at least one piece.")

# Categorize pieces
text_pieces = [p for p in pieces if p.converted_value_data_type == "text"]
image_pieces = [p for p in pieces if p.converted_value_data_type == "image_path"]
other_pieces = [p for p in pieces if p.converted_value_data_type not in ("text", "image_path")]
Comment on lines +378 to +380 (Contributor):

You can use the function I suggested above here:

text_pieces = message.get_pieces_by_type(data_type="text")
image_pieces = message.get_pieces_by_type(data_type="image_path")
other_pieces = message.get_pieces_by_type(exclude_types={"text", "image_path"})


# Must have exactly one text piece
if len(text_pieces) != 1:
raise ValueError(f"Expected exactly 1 text piece, got {len(text_pieces)}.")

# At most one image piece
if len(image_pieces) > 1:
raise ValueError(f"Expected at most 1 image piece, got {len(image_pieces)}.")

# No other data types allowed
if other_pieces:
types = [p.converted_value_data_type for p in other_pieces]
raise ValueError(f"Unsupported piece types: {types}. Only 'text' and 'image_path' are supported.")

# Check for conflicting modes: remix + image
text_piece = text_pieces[0]
remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None
if remix_video_id and image_pieces:
raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.")

def is_json_response_supported(self) -> bool:
"""
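
A usage sketch of the three modes the class docstring describes. This is a hypothetical example: the MessagePiece constructor arguments, the import paths, and the OpenAIVideoTarget kwargs are assumed from PyRIT conventions, not confirmed by this diff.

import asyncio

from pyrit.models import Message, MessagePiece
from pyrit.prompt_target import OpenAIVideoTarget


async def main() -> None:
    # Constructor kwargs are illustrative; real names may differ.
    target = OpenAIVideoTarget(model_name="sora-2", resolution_dimensions="1280x720", n_seconds=4)

    # Text-to-video: a single text piece.
    text_piece = MessagePiece(role="user", original_value="A red fox running at dawn", original_value_data_type="text")
    responses = await target.send_prompt_async(message=Message(message_pieces=[text_piece]))

    # The response carries the generated video's id in prompt_metadata (see _save_video_response above).
    video_id = responses[0].message_pieces[0].prompt_metadata["video_id"]

    # Remix: reuse that id on a new text piece via prompt_metadata.
    remix_piece = MessagePiece(
        role="user",
        original_value="Same scene, but snowing",
        original_value_data_type="text",
        prompt_metadata={"video_id": video_id},
    )
    await target.send_prompt_async(message=Message(message_pieces=[remix_piece]))

    # Image-to-video: pair a text piece with an image_path piece.
    prompt_piece = MessagePiece(role="user", original_value="Animate this scene", original_value_data_type="text")
    image_piece = MessagePiece(role="user", original_value="/tmp/first_frame.png", original_value_data_type="image_path")
    await target.send_prompt_async(message=Message(message_pieces=[prompt_piece, image_piece]))


asyncio.run(main())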