Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 82 additions & 24 deletions src/rapidata/rapidata_client/rapidata_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from __future__ import annotations

import json
import os
import threading
import time
from dataclasses import dataclass
from typing import Any
import requests
from packaging import version
Expand Down Expand Up @@ -37,6 +42,19 @@
)


# Cache userinfo process-wide so request bursts that spin up many short-lived
# RapidataClient instances don't hammer identity-service with redundant calls.
@dataclass
class _UserInfoCacheEntry:
result: dict[str, Any]
expires_at: float


_USERINFO_CACHE_TTL_SECONDS = 24 * 60 * 60
_userinfo_cache: dict[tuple[str, str], _UserInfoCacheEntry] = {}
_userinfo_cache_lock = threading.Lock()


class RapidataClient:
"""The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""

Expand Down Expand Up @@ -104,6 +122,9 @@ def __init__(
if environment is None:
environment = os.environ.get("RAPIDATA_ENVIRONMENT") or "rapidata.ai"

self._client_id = client_id
self._environment = environment

with tracer.start_as_current_span("RapidataClient.__init__"):
logger.debug("Checking version")
self._check_version()
Expand Down Expand Up @@ -156,17 +177,65 @@ def reset_credentials(self):
"""Reset the credentials saved in the configuration file for the current environment."""
logger.info("Resetting credentials")
self._openapi_service.reset_credentials()
if self._client_id is not None:
with _userinfo_cache_lock:
_userinfo_cache.pop((self._environment, self._client_id), None)
logger.info("Credentials reset")

def clear_all_caches(self):
"""Clear all caches for the client."""
self._asset_uploader.clear_cache()
with _userinfo_cache_lock:
_userinfo_cache.clear()
logger.info("All caches cleared")

def _apply_userinfo(self, result: dict[str, Any]) -> None:
sub = result.get("sub")
email = result.get("email")
if sub and email:
tracer.set_user_info(client_id=sub, email=email)

# OIDC userinfo returns `role` as a list when there are
# multiple, or a bare string when there is exactly one.
# A substring check like `"Admin" in result.get("role", [])`
# matches `"Administrator"`, `"SuperAdmin"`, etc., so do an
# explicit equality check against a normalized list.
roles_raw = result.get("role", [])
if isinstance(roles_raw, str):
roles = [roles_raw]
elif isinstance(roles_raw, list):
roles = roles_raw
else:
roles = []

if "Admin" not in roles:
logger.debug("User is not an admin, not enabling beta features")
return

logger.debug("User is an admin, enabling beta features")
rapidata_config.enableBetaFeatures = True

def _check_beta_features(self):
"""Enable beta features for the client."""
with optional_api_call("check beta features"):
with tracer.start_as_current_span("RapidataClient.check_beta_features"):
cache_key: tuple[str, str] | None = (
(self._environment, self._client_id)
if self._client_id is not None
else None
)
if cache_key is not None:
with _userinfo_cache_lock:
entry = _userinfo_cache.get(cache_key)
if entry is not None and entry.expires_at > time.monotonic():
cached_result = entry.result
else:
cached_result = None
if cached_result is not None:
logger.debug("Userinfo cache hit for %s", cache_key)
self._apply_userinfo(cached_result)
return

result: dict[str, Any] = json.loads(
self._openapi_service.api_client.call_api(
"GET",
Expand All @@ -178,30 +247,19 @@ def _check_beta_features(self):
)
logger.debug("Userinfo: %s", result)

client_id = result.get("sub")
email = result.get("email")
if client_id and email:
tracer.set_user_info(client_id=client_id, email=email)

# OIDC userinfo returns `role` as a list when there are
# multiple, or a bare string when there is exactly one.
# A substring check like `"Admin" in result.get("role", [])`
# matches `"Administrator"`, `"SuperAdmin"`, etc., so do an
# explicit equality check against a normalized list.
roles_raw = result.get("role", [])
if isinstance(roles_raw, str):
roles = [roles_raw]
elif isinstance(roles_raw, list):
roles = roles_raw
else:
roles = []

if "Admin" not in roles:
logger.debug("User is not an admin, not enabling beta features")
return

logger.debug("User is an admin, enabling beta features")
rapidata_config.enableBetaFeatures = True
effective_key = cache_key
if effective_key is None:
sub = result.get("sub")
if sub:
effective_key = (self._environment, sub)
if effective_key is not None:
with _userinfo_cache_lock:
_userinfo_cache[effective_key] = _UserInfoCacheEntry(
result=result,
expires_at=time.monotonic() + _USERINFO_CACHE_TTL_SECONDS,
)

self._apply_userinfo(result)

def _check_version(self):
with optional_api_call("version check"):
Expand Down
Loading