Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 85 additions & 7 deletions src/rapidata/rapidata_client/dataset/_rapidata_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def add_datapoints(
executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers)

# 3. Execute uploads and incremental datapoint creation
self._execute_incremental_creation(
asset_failures = self._execute_incremental_creation(
datapoints,
asset_to_datapoints,
datapoint_pending_count,
Expand All @@ -105,7 +105,12 @@ def add_datapoints(

# 5. Collect and return results
return self._collect_and_return_results(
datapoints, creation_futures, datapoint_pending_count, lock
datapoints,
creation_futures,
datapoint_pending_count,
lock,
asset_failures,
asset_to_datapoints,
)

def _build_asset_to_datapoint_mapping(
Expand Down Expand Up @@ -157,7 +162,7 @@ def _execute_incremental_creation(
creation_futures: list[tuple[int, Future]],
lock: threading.Lock,
executor: ThreadPoolExecutor,
) -> None:
) -> list[FailedUpload[str]]:
"""
Execute asset uploads and incremental datapoint creation.

Expand All @@ -168,7 +173,12 @@ def _execute_incremental_creation(
creation_futures: List to store creation futures.
lock: Lock protecting shared state.
executor: Thread pool executor for datapoint creation.

Returns:
Asset-level failures from the upload phase, so callers can map them
back to the datapoints they blocked.
"""
asset_failures: list[FailedUpload[str]] = []
# Create progress bar for datapoint creation
datapoint_pbar = tqdm(
total=len(datapoints),
Expand Down Expand Up @@ -225,6 +235,8 @@ def _execute_incremental_creation(
# Always close progress bar, even on exception
datapoint_pbar.close()

return asset_failures

def _create_asset_completion_callback(
self,
datapoints: list[Datapoint],
Expand Down Expand Up @@ -369,6 +381,8 @@ def _collect_and_return_results(
creation_futures: list[tuple[int, Future]],
datapoint_pending_count: dict[int, int],
lock: threading.Lock,
asset_failures: list[FailedUpload[str]],
asset_to_datapoints: dict[str, set[int]],
) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]:
"""
Collect results from datapoint creation tasks.
Expand All @@ -378,6 +392,10 @@ def _collect_and_return_results(
creation_futures: List of creation futures.
datapoint_pending_count: Datapoints whose assets failed.
lock: Lock protecting datapoint_pending_count.
asset_failures: Asset-level failures from the upload phase, used to
attach the underlying error reasons and trace IDs to the
datapoint-level FailedUpload entries.
asset_to_datapoints: Mapping from asset to datapoint indices.

Returns:
Tuple of (successful_uploads, failed_uploads).
Expand All @@ -395,15 +413,25 @@ def _collect_and_return_results(
# Use from_exception to extract proper error reason from RapidataError
failed_uploads.append(FailedUpload.from_exception(datapoints[idx], e))

# Build reverse mapping: datapoint index -> the asset-level failures
# that blocked it. A single datapoint can have multiple required
# assets fail, so we accumulate all of them.
datapoint_to_asset_failures: dict[int, list[FailedUpload[str]]] = {}
for asset_failure in asset_failures:
affected = asset_to_datapoints.get(asset_failure.item, set())
for dp_idx in affected:
datapoint_to_asset_failures.setdefault(dp_idx, []).append(
asset_failure
)

# Handle datapoints whose assets failed to upload
with lock:
for idx in datapoint_pending_count:
logger.warning(f"Datapoint {idx} assets failed to upload")
failed_uploads.append(
FailedUpload(
item=datapoints[idx],
error_type="AssetUploadFailed",
error_message="One or more required assets failed to upload",
self._build_asset_failure_for_datapoint(
datapoints[idx],
datapoint_to_asset_failures.get(idx, []),
)
)

Expand All @@ -412,6 +440,56 @@ def _collect_and_return_results(
)
return successful_uploads, failed_uploads

@staticmethod
def _build_asset_failure_for_datapoint(
datapoint: Datapoint,
related_asset_failures: list[FailedUpload[str]],
) -> FailedUpload[Datapoint]:
"""
Build a datapoint-level FailedUpload from the asset failures that
blocked it.

The error message reuses the underlying asset errors' reasons so the
user sees the real cause (not just "assets failed"), and the trace IDs
from any RapidataError-backed failures are aggregated so each blocked
datapoint surfaces every backend trace involved.
"""
if not related_asset_failures:
return FailedUpload(
item=datapoint,
error_type="AssetUploadFailed",
error_message="One or more required assets failed to upload",
)

# Deduplicate reasons while preserving order so the message is stable.
seen_reasons: set[str] = set()
unique_reasons: list[str] = []
for fu in related_asset_failures:
if fu.error_message not in seen_reasons:
seen_reasons.add(fu.error_message)
unique_reasons.append(fu.error_message)

if len(unique_reasons) == 1:
error_message = f"Asset upload failed: {unique_reasons[0]}"
else:
error_message = "Asset upload failed: " + "; ".join(unique_reasons)

seen_trace_ids: set[str] = set()
unique_trace_ids: list[str] = []
for fu in related_asset_failures:
if fu.trace_id and fu.trace_id not in seen_trace_ids:
seen_trace_ids.add(fu.trace_id)
unique_trace_ids.append(fu.trace_id)

trace_id = ", ".join(unique_trace_ids) if unique_trace_ids else None

return FailedUpload(
item=datapoint,
error_type="AssetUploadFailed",
error_message=error_message,
trace_id=trace_id,
)

def _create_dataset_groups(self, datapoints: list[Datapoint]) -> None:
"""Create dataset groups from datapoints that have a group field."""
from rapidata.api_client.models.create_dataset_group_endpoint_input import (
Expand Down
16 changes: 15 additions & 1 deletion src/rapidata/rapidata_client/exceptions/failed_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@ class FailedUpload(Generic[T]):
error_type: The type of the exception (e.g., "RapidataError").
timestamp: Optional timestamp when the failure occurred.
exception: Optional original exception for richer error context.
trace_id: Optional backend trace ID, when the failure originated from a
RapidataError whose response included a traceId. Used to correlate
an SDK-side failure with the backend trace that produced it.
"""

item: T
error_message: str
error_type: str
timestamp: Optional[datetime] = field(default_factory=datetime.now)
exception: Optional[Exception] = None
trace_id: Optional[str] = None

@classmethod
def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T]:
"""
Create a FailedUpload from an item and exception.

For RapidataError exceptions, extracts the clean API error reason.
For RapidataError exceptions, extracts the clean API error reason and
the backend trace ID (when present in the error response).
For other exceptions, uses the string representation.

Args:
Expand All @@ -54,9 +59,14 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T]
from rapidata.rapidata_client.exceptions.rapidata_error import RapidataError

error_type = type(exception).__name__
trace_id: Optional[str] = None

if isinstance(exception, RapidataError):
error_message = exception.get_reason()
if isinstance(exception.details, dict):
raw_trace_id = exception.details.get("traceId")
if isinstance(raw_trace_id, str) and raw_trace_id:
trace_id = raw_trace_id
else:
error_message = str(exception)

Expand All @@ -65,6 +75,7 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T]
error_message=error_message,
error_type=error_type,
exception=exception,
trace_id=trace_id,
)

def format_error_details(self) -> str:
Expand All @@ -80,6 +91,9 @@ def format_error_details(self) -> str:
f"Error Message: {self.error_message}",
]

if self.trace_id:
details.append(f"Trace Id: {self.trace_id}")

if self.timestamp:
details.append(f"Timestamp: {self.timestamp.isoformat()}")

Expand Down
16 changes: 13 additions & 3 deletions src/rapidata/rapidata_client/exceptions/failed_upload_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,20 @@ def __str__(self) -> str:

lines = [f"{total} datapoint(s) failed to upload:"]

for reason, datapoints in self.failures_by_reason.items():
# Group internally on the full FailedUpload so each item can carry its
# own trace ID (the public failures_by_reason groups by reason only and
# discards that detail).
grouped: dict[str, list[FailedUpload[Datapoint]]] = defaultdict(list)
for fu in self._failed_uploads:
grouped[fu.error_message].append(fu)

for reason, failures in grouped.items():
lines.append(f" '{reason}': [")
for dp in datapoints:
lines.append(f" {dp},")
for fu in failures:
if fu.trace_id:
lines.append(f" {fu.item} [trace_id={fu.trace_id}],")
else:
lines.append(f" {fu.item},")
lines.append(" ]")

failed_upload_message = "\n".join(lines)
Expand Down
Loading