Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions netra/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from netra.simulation.models import (
ConversationResponse,
ConversationStatus,
FileData,
ProcessedFile,
SimulationItem,
TaskResult,
)
Expand All @@ -12,6 +14,8 @@
"BaseTask",
"ConversationResponse",
"ConversationStatus",
"FileData",
"ProcessedFile",
"SimulationItem",
"TaskResult",
]
8 changes: 6 additions & 2 deletions netra/simulation/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from netra.config import Config
from netra.simulation.client import SimulationHttpClient
from netra.simulation.models import SimulationItem
from netra.simulation.models import FileData, SimulationItem
from netra.simulation.task import BaseTask
from netra.simulation.utils import (
execute_task,
Expand Down Expand Up @@ -197,6 +197,7 @@ async def _execute_conversation(
run_item_id = run_item.run_item_id
message = run_item.message
turn_id = run_item.turn_id
raw_files: list[FileData] = run_item.files
session_id: Optional[str] = None

while True:
Expand All @@ -208,7 +209,9 @@ async def _execute_conversation(
span_context = otel_span.get_span_context()
trace_id = format_trace_id(span_context.trace_id)

response_message, task_session_id = await execute_task(task, message, session_id)
response_message, task_session_id = await execute_task(
task, message, session_id, raw_files=raw_files
)
if task_session_id:
session_id = task_session_id

Expand Down Expand Up @@ -243,6 +246,7 @@ async def _execute_conversation(

message = response.next_user_message # type:ignore[assignment]
turn_id = response.next_turn_id # type:ignore[assignment]
raw_files = response.next_files

except Exception as exc:
error_msg = str(exc)
Expand Down
39 changes: 38 additions & 1 deletion netra/simulation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import httpx

from netra.config import Config
from netra.simulation.models import ConversationResponse, SimulationItem
from netra.simulation.models import ConversationResponse, FileData, SimulationItem

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,6 +148,7 @@ def create_run(
run_item_id=msg.get("testRunItemId", ""),
message=msg.get("userMessage", ""),
turn_id=msg.get("turnId", ""),
files=self._parse_files(msg.get("files")),
)
for msg in user_messages
]
Expand Down Expand Up @@ -217,6 +218,7 @@ def trigger_conversation(
next_turn_id=next_msg.get("turnId", ""),
next_user_message=next_msg.get("userMessage", ""),
next_run_item_id=next_msg.get("testRunItemId", ""),
next_files=self._parse_files(next_msg.get("files")),
)

except Exception as exc:
Expand Down Expand Up @@ -277,6 +279,41 @@ def post_run_status(self, run_id: str, status: str) -> Any:
logger.error("%s: Failed to post run status for run '%s': %s", _LOG_PREFIX, run_id, error_msg)
return {"success": False}

@staticmethod
def _parse_files(raw_files: Any) -> list[FileData]:
"""Parse raw file entries from the backend response into FileData objects.

Args:
raw_files: List of file dictionaries from the JSON response, or None.

Returns:
List of FileData objects. Malformed entries are skipped.
"""
if not raw_files or not isinstance(raw_files, list):
return []

parsed: list[FileData] = []
for entry in raw_files:
if not isinstance(entry, dict):
continue
file_name = entry.get("fileName", "")
download_url = entry.get("downloadUrl", "")
if not file_name or not download_url:
logger.warning(
"%s: Skipping file entry with missing fileName or downloadUrl",
_LOG_PREFIX,
)
continue
parsed.append(
FileData(
file_name=file_name,
content_type=entry.get("contentType", ""),
description=entry.get("description"),
download_url=download_url,
)
)
return parsed

def _extract_error_message(
self,
response: Optional[httpx.Response],
Expand Down
40 changes: 39 additions & 1 deletion netra/simulation/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Data models for the simulation module."""

from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

Expand All @@ -12,6 +12,40 @@ class ConversationStatus(Enum):
STOP = "stop"


@dataclass(slots=True, frozen=True)
class FileData:
"""Raw file metadata received from the backend.

Attributes:
file_name: Name of the file.
content_type: MIME type of the file content.
description: Optional description of the file.
download_url: Pre-signed URL to download the file.
"""

file_name: str
content_type: str
description: Optional[str]
download_url: str


@dataclass(slots=True, frozen=True)
class ProcessedFile:
"""File after download and base64 encoding, delivered to the user task.

Attributes:
file_name: Name of the file.
content_type: MIME type of the file content.
description: Optional description of the file.
data: Base64-encoded file content.
"""

file_name: str
content_type: str
description: Optional[str]
data: str


@dataclass(slots=True, frozen=True)
class SimulationItem:
"""Represents a single item in a simulation run.
Expand All @@ -20,11 +54,13 @@ class SimulationItem:
run_item_id: Unique identifier for the run item.
message: The user message content.
turn_id: Identifier for the conversation turn.
files: File metadata attached to this item.
"""

run_item_id: str
message: str
turn_id: str
files: list[FileData] = field(default_factory=list)


@dataclass(slots=True)
Expand All @@ -37,13 +73,15 @@ class ConversationResponse:
next_turn_id: Identifier for the next turn if continuing.
next_user_message: The next user message if continuing.
next_run_item_id: Identifier for the next run item if continuing.
next_files: File metadata for the next turn if continuing.
"""

decision: str
reason: Optional[str] = None
next_turn_id: Optional[str] = None
next_user_message: Optional[str] = None
next_run_item_id: Optional[str] = None
next_files: list[FileData] = field(default_factory=list)


@dataclass(slots=True, frozen=True)
Expand Down
39 changes: 35 additions & 4 deletions netra/simulation/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC, abstractmethod
from typing import Awaitable, Optional
from typing import Any, Awaitable, Optional

from netra.simulation.models import TaskResult

Expand All @@ -18,8 +18,9 @@ class BaseTask(ABC):
Subclasses must:
- Implement run(): Executes the task logic and returns a TaskResult.

The run method receives a message and optional session_id, and must return
a TaskResult containing the response message and session_id.
The base contract requires ``message`` and ``session_id``. Subclasses that
need file attachments can add a ``files`` keyword argument — the framework
detects it via introspection and will pass downloaded files automatically.

Example:
class MyTask(BaseTask):
Expand All @@ -37,6 +38,24 @@ def run(self, message: str, session_id: Optional[str] = None) -> TaskResult:
task=MyTask(),
)

Example with file uploads:
class MyFileTask(BaseTask):
def run(
self,
message: str,
session_id: Optional[str] = None,
files: Optional[list[ProcessedFile]] = None,
) -> TaskResult:
# Access base64-encoded file data
if files:
for f in files:
print(f.file_name, f.content_type, len(f.data))
response = my_agent.chat(message, session_id=session_id, files=files)
return TaskResult(
message=response.text,
session_id=response.session_id or session_id or "default",
)

Async Example:
class MyAsyncTask(BaseTask):
async def run(self, message: str, session_id: Optional[str] = None) -> TaskResult:
Expand All @@ -49,17 +68,29 @@ async def run(self, message: str, session_id: Optional[str] = None) -> TaskResul
"""

@abstractmethod
def run(self, message: str, session_id: Optional[str] = None) -> TaskResult | Awaitable[TaskResult]:
def run(
self,
message: str,
session_id: Optional[str] = None,
**kwargs: Any,
) -> TaskResult | Awaitable[TaskResult]:
"""
Execute the task logic.

This method can be sync or async. If async, the framework will
await the coroutine automatically.

The base signature requires only ``message`` and ``session_id``.
Subclasses that handle file attachments should declare an additional
``files: Optional[list[ProcessedFile]] = None`` parameter — the
framework will supply it automatically when the dataset item includes
file attachments.

Args:
message: The input message from the simulation.
session_id: Optional session identifier for conversation continuity.
Will be None for the first turn of a conversation.
**kwargs: Reserved for forward-compatible extensions (e.g. ``files``).

Returns:
TaskResult: The task result containing:
Expand Down
Loading