From 61df3c090a2e3982b5891b2eb5f3c6e7e346b775 Mon Sep 17 00:00:00 2001 From: akash-vijay-kv Date: Mon, 18 May 2026 10:14:54 +0530 Subject: [PATCH 1/2] [NET-856] feat: Add file handling support in the simulation workflow --- netra/simulation/__init__.py | 4 ++ netra/simulation/api.py | 8 ++- netra/simulation/client.py | 39 ++++++++++++- netra/simulation/models.py | 40 ++++++++++++- netra/simulation/task.py | 39 +++++++++++-- netra/simulation/utils.py | 107 ++++++++++++++++++++++++++++++++++- 6 files changed, 227 insertions(+), 10 deletions(-) diff --git a/netra/simulation/__init__.py b/netra/simulation/__init__.py index 79c7cfd..efcb9d0 100644 --- a/netra/simulation/__init__.py +++ b/netra/simulation/__init__.py @@ -2,6 +2,8 @@ from netra.simulation.models import ( ConversationResponse, ConversationStatus, + FileData, + ProcessedFile, SimulationItem, TaskResult, ) @@ -12,6 +14,8 @@ "BaseTask", "ConversationResponse", "ConversationStatus", + "FileData", + "ProcessedFile", "SimulationItem", "TaskResult", ] diff --git a/netra/simulation/api.py b/netra/simulation/api.py index dfea287..5340875 100644 --- a/netra/simulation/api.py +++ b/netra/simulation/api.py @@ -8,7 +8,7 @@ from netra.config import Config from netra.simulation.client import SimulationHttpClient -from netra.simulation.models import SimulationItem +from netra.simulation.models import FileData, SimulationItem from netra.simulation.task import BaseTask from netra.simulation.utils import ( execute_task, @@ -197,6 +197,7 @@ async def _execute_conversation( run_item_id = run_item.run_item_id message = run_item.message turn_id = run_item.turn_id + raw_files: list[FileData] = run_item.files session_id: Optional[str] = None while True: @@ -208,7 +209,9 @@ async def _execute_conversation( span_context = otel_span.get_span_context() trace_id = format_trace_id(span_context.trace_id) - response_message, task_session_id = await execute_task(task, message, session_id) + response_message, task_session_id = await 
execute_task( + task, message, session_id, raw_files=raw_files + ) if task_session_id: session_id = task_session_id @@ -243,6 +246,7 @@ async def _execute_conversation( message = response.next_user_message # type:ignore[assignment] turn_id = response.next_turn_id # type:ignore[assignment] + raw_files = response.next_files except Exception as exc: error_msg = str(exc) diff --git a/netra/simulation/client.py b/netra/simulation/client.py index d495185..b9f1322 100644 --- a/netra/simulation/client.py +++ b/netra/simulation/client.py @@ -7,7 +7,7 @@ import httpx from netra.config import Config -from netra.simulation.models import ConversationResponse, SimulationItem +from netra.simulation.models import ConversationResponse, FileData, SimulationItem logger = logging.getLogger(__name__) @@ -148,6 +148,7 @@ def create_run( run_item_id=msg.get("testRunItemId", ""), message=msg.get("userMessage", ""), turn_id=msg.get("turnId", ""), + files=self._parse_files(msg.get("files")), ) for msg in user_messages ] @@ -217,6 +218,7 @@ def trigger_conversation( next_turn_id=next_msg.get("turnId", ""), next_user_message=next_msg.get("userMessage", ""), next_run_item_id=next_msg.get("testRunItemId", ""), + next_files=self._parse_files(next_msg.get("files")), ) except Exception as exc: @@ -277,6 +279,41 @@ def post_run_status(self, run_id: str, status: str) -> Any: logger.error("%s: Failed to post run status for run '%s': %s", _LOG_PREFIX, run_id, error_msg) return {"success": False} + @staticmethod + def _parse_files(raw_files: Any) -> list[FileData]: + """Parse raw file entries from the backend response into FileData objects. + + Args: + raw_files: List of file dictionaries from the JSON response, or None. + + Returns: + List of FileData objects. Malformed entries are skipped. 
+ """ + if not raw_files or not isinstance(raw_files, list): + return [] + + parsed: list[FileData] = [] + for entry in raw_files: + if not isinstance(entry, dict): + continue + file_name = entry.get("fileName", "") + download_url = entry.get("downloadUrl", "") + if not file_name or not download_url: + logger.warning( + "%s: Skipping file entry with missing fileName or downloadUrl", + _LOG_PREFIX, + ) + continue + parsed.append( + FileData( + file_name=file_name, + content_type=entry.get("contentType", ""), + description=entry.get("description"), + download_url=download_url, + ) + ) + return parsed + def _extract_error_message( self, response: Optional[httpx.Response], diff --git a/netra/simulation/models.py b/netra/simulation/models.py index 258f655..570a765 100644 --- a/netra/simulation/models.py +++ b/netra/simulation/models.py @@ -1,6 +1,6 @@ """Data models for the simulation module.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Optional @@ -12,6 +12,40 @@ class ConversationStatus(Enum): STOP = "stop" +@dataclass(slots=True, frozen=True) +class FileData: + """Raw file metadata received from the backend. + + Attributes: + file_name: Name of the file. + content_type: MIME type of the file content. + description: Optional description of the file. + download_url: Pre-signed URL to download the file. + """ + + file_name: str + content_type: str + description: Optional[str] + download_url: str + + +@dataclass(slots=True, frozen=True) +class ProcessedFile: + """File after download and base64 encoding, delivered to the user task. + + Attributes: + file_name: Name of the file. + content_type: MIME type of the file content. + description: Optional description of the file. + data: Base64-encoded file content. 
+ """ + + file_name: str + content_type: str + description: Optional[str] + data: str + + @dataclass(slots=True, frozen=True) class SimulationItem: """Represents a single item in a simulation run. @@ -20,11 +54,13 @@ class SimulationItem: run_item_id: Unique identifier for the run item. message: The user message content. turn_id: Identifier for the conversation turn. + files: File metadata attached to this item. """ run_item_id: str message: str turn_id: str + files: list[FileData] = field(default_factory=list) @dataclass(slots=True) @@ -37,6 +73,7 @@ class ConversationResponse: next_turn_id: Identifier for the next turn if continuing. next_user_message: The next user message if continuing. next_run_item_id: Identifier for the next run item if continuing. + next_files: File metadata for the next turn if continuing. """ decision: str @@ -44,6 +81,7 @@ class ConversationResponse: next_turn_id: Optional[str] = None next_user_message: Optional[str] = None next_run_item_id: Optional[str] = None + next_files: list[FileData] = field(default_factory=list) @dataclass(slots=True, frozen=True) diff --git a/netra/simulation/task.py b/netra/simulation/task.py index bdfc39b..126630d 100644 --- a/netra/simulation/task.py +++ b/netra/simulation/task.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import Awaitable, Optional +from typing import Any, Awaitable, Optional from netra.simulation.models import TaskResult @@ -18,8 +18,9 @@ class BaseTask(ABC): Subclasses must: - Implement run(): Executes the task logic and returns a TaskResult. - The run method receives a message and optional session_id, and must return - a TaskResult containing the response message and session_id. + The base contract requires ``message`` and ``session_id``. Subclasses that + need file attachments can add a ``files`` keyword argument — the framework + detects it via introspection and will pass downloaded files automatically. 
Example: class MyTask(BaseTask): @@ -37,6 +38,24 @@ def run(self, message: str, session_id: Optional[str] = None) -> TaskResult: task=MyTask(), ) + Example with file uploads: + class MyFileTask(BaseTask): + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: + # Access base64-encoded file data + if files: + for f in files: + print(f.file_name, f.content_type, len(f.data)) + response = my_agent.chat(message, session_id=session_id, files=files) + return TaskResult( + message=response.text, + session_id=response.session_id or session_id or "default", + ) + Async Example: class MyAsyncTask(BaseTask): async def run(self, message: str, session_id: Optional[str] = None) -> TaskResult: @@ -49,17 +68,29 @@ async def run(self, message: str, session_id: Optional[str] = None) -> TaskResul """ @abstractmethod - def run(self, message: str, session_id: Optional[str] = None) -> TaskResult | Awaitable[TaskResult]: + def run( + self, + message: str, + session_id: Optional[str] = None, + **kwargs: Any, + ) -> TaskResult | Awaitable[TaskResult]: """ Execute the task logic. This method can be sync or async. If async, the framework will await the coroutine automatically. + The base signature requires only ``message`` and ``session_id``. + Subclasses that handle file attachments should declare an additional + ``files: Optional[list[ProcessedFile]] = None`` parameter — the + framework will supply it automatically when the dataset item includes + file attachments. + Args: message: The input message from the simulation. session_id: Optional session identifier for conversation continuity. Will be None for the first turn of a conversation. + **kwargs: Reserved for forward-compatible extensions (e.g. ``files``). 
logger = logging.getLogger(__name__)

_LOG_PREFIX = "netra.simulation"
_DEFAULT_FILE_DOWNLOAD_TIMEOUT = 30.0


def _get_file_download_timeout() -> float:
    """Get file download timeout from environment or use default.

    Reads ``NETRA_SIMULATION_FILE_DOWNLOAD_TIMEOUT``. The value must parse as
    a finite number greater than zero; anything else (unparsable text, zero,
    negatives, ``nan``, ``inf``) is logged and replaced by the default.

    Returns:
        The timeout value in seconds.
    """
    timeout_str = os.getenv("NETRA_SIMULATION_FILE_DOWNLOAD_TIMEOUT")
    if not timeout_str:
        return _DEFAULT_FILE_DOWNLOAD_TIMEOUT

    try:
        timeout: Optional[float] = float(timeout_str)
    except ValueError:
        timeout = None

    # float() happily accepts "nan", "inf" and negatives. NaN fails every
    # comparison, so this single range check rejects all of them.
    if timeout is not None and 0 < timeout < float("inf"):
        return timeout

    logger.warning(
        "%s: Invalid file download timeout '%s', using default %.1f",
        _LOG_PREFIX,
        timeout_str,
        _DEFAULT_FILE_DOWNLOAD_TIMEOUT,
    )
    return _DEFAULT_FILE_DOWNLOAD_TIMEOUT


def _describe_download_error(exc: BaseException) -> str:
    """Summarise a download failure without echoing the request URL.

    Exceptions that carry a ``response`` with a ``status_code`` are reported
    as ``"HTTP <code>"``; everything else is reported by exception class name.
    The exception's own text is deliberately not used because it can embed
    the pre-signed download URL.
    """
    status_code = getattr(getattr(exc, "response", None), "status_code", None)
    if status_code is not None:
        return f"HTTP {status_code}"
    return type(exc).__name__


def process_files(files: "list[FileData]") -> "list[ProcessedFile]":
    """Download files from pre-signed URLs and base64-encode their content.

    Files are fetched one after another over a single shared HTTP client. If
    any file fails to download, the entire batch is aborted with a
    ``RuntimeError`` so that file-aware tasks never receive a partial file
    list.

    NOTE(review): this performs blocking network I/O. ``execute_task`` calls
    it from inside a coroutine, so the event loop is stalled for the duration
    of the downloads; consider ``await asyncio.to_thread(process_files, ...)``
    at that call site.

    Args:
        files: List of FileData objects containing download URLs.

    Returns:
        List of ProcessedFile objects with base64-encoded data, in the same
        order as ``files``.

    Raises:
        RuntimeError: If a file download or encoding fails. The message names
            the file and a short reason (HTTP status or exception class); the
            original exception is chained as ``__cause__``.
    """
    if not files:
        return []

    timeout = _get_file_download_timeout()
    processed: list[ProcessedFile] = []

    # One client for the whole batch: connections are pooled and reliably
    # closed when the block exits, including on the error path.
    with httpx.Client(timeout=timeout) as client:
        for file_data in files:
            try:
                response = client.get(file_data.download_url)
                response.raise_for_status()
                encoded = base64.b64encode(response.content).decode("ascii")
            except Exception as exc:
                reason = _describe_download_error(exc)
                raise RuntimeError(f"Failed to download file '{file_data.file_name}': {reason}") from exc

            processed.append(
                ProcessedFile(
                    file_name=file_data.file_name,
                    content_type=file_data.content_type,
                    description=file_data.description,
                    data=encoded,
                )
            )

    return processed


def _task_accepts_files(task: "BaseTask") -> bool:
    """Check whether the task's run() method can be called with ``files=...``.

    Used for backward compatibility so that existing BaseTask subclasses that
    do not declare the files parameter are not broken.

    Args:
        task: The BaseTask instance to inspect.

    Returns:
        True if run() declares a ``files`` parameter that can be passed by
        keyword, or accepts ``**kwargs``. False otherwise, including when the
        signature cannot be introspected or ``files`` is positional-only.
    """
    try:
        parameters = inspect.signature(task.run).parameters.values()
    except (ValueError, TypeError):
        return False

    keyword_capable = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    for param in parameters:
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # The caller passes files by keyword, so a positional-only parameter
        # that happens to be named "files" must not count.
        if param.name == "files" and param.kind in keyword_capable:
            return True
    return False
@@ -106,7 +205,11 @@ async def execute_task( Raises: ValueError: If the task returns an unsupported type. """ - result = task.run(message=message, session_id=session_id) + kwargs: dict[str, object] = {"message": message, "session_id": session_id} + if raw_files and _task_accepts_files(task): + kwargs["files"] = process_files(raw_files) + + result = task.run(**kwargs) # type: ignore[arg-type] if asyncio.iscoroutine(result): result = await result From 80e3c1554ece7458340097f858b1ecf62fb574eb Mon Sep 17 00:00:00 2001 From: akash-vijay-kv Date: Tue, 19 May 2026 09:30:17 +0530 Subject: [PATCH 2/2] chore: Rename the file response key to attachments --- netra/simulation/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netra/simulation/client.py b/netra/simulation/client.py index b9f1322..77e9b5f 100644 --- a/netra/simulation/client.py +++ b/netra/simulation/client.py @@ -148,7 +148,7 @@ def create_run( run_item_id=msg.get("testRunItemId", ""), message=msg.get("userMessage", ""), turn_id=msg.get("turnId", ""), - files=self._parse_files(msg.get("files")), + files=self._parse_files(msg.get("attachments")), ) for msg in user_messages ] @@ -218,7 +218,7 @@ def trigger_conversation( next_turn_id=next_msg.get("turnId", ""), next_user_message=next_msg.get("userMessage", ""), next_run_item_id=next_msg.get("testRunItemId", ""), - next_files=self._parse_files(next_msg.get("files")), + next_files=self._parse_files(next_msg.get("attachments")), ) except Exception as exc: