diff --git a/src/rapidata/rapidata_client/rapidata_client.py b/src/rapidata/rapidata_client/rapidata_client.py index f7e3e535f..ad7d9f84c 100644 --- a/src/rapidata/rapidata_client/rapidata_client.py +++ b/src/rapidata/rapidata_client/rapidata_client.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import json import os +import threading +import time +from dataclasses import dataclass from typing import Any import requests from packaging import version @@ -37,6 +42,19 @@ ) +# Cache userinfo process-wide so request bursts that spin up many short-lived +# RapidataClient instances don't hammer identity-service with redundant calls. +@dataclass +class _UserInfoCacheEntry: + result: dict[str, Any] + expires_at: float + + +_USERINFO_CACHE_TTL_SECONDS = 24 * 60 * 60 +_userinfo_cache: dict[tuple[str, str], _UserInfoCacheEntry] = {} +_userinfo_cache_lock = threading.Lock() + + class RapidataClient: """The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets.""" @@ -104,6 +122,9 @@ def __init__( if environment is None: environment = os.environ.get("RAPIDATA_ENVIRONMENT") or "rapidata.ai" + self._client_id = client_id + self._environment = environment + with tracer.start_as_current_span("RapidataClient.__init__"): logger.debug("Checking version") self._check_version() @@ -156,17 +177,65 @@ def reset_credentials(self): """Reset the credentials saved in the configuration file for the current environment.""" logger.info("Resetting credentials") self._openapi_service.reset_credentials() + if self._client_id is not None: + with _userinfo_cache_lock: + _userinfo_cache.pop((self._environment, self._client_id), None) logger.info("Credentials reset") def clear_all_caches(self): """Clear all caches for the client.""" self._asset_uploader.clear_cache() + with _userinfo_cache_lock: + _userinfo_cache.clear() logger.info("All caches cleared") + def _apply_userinfo(self, result: dict[str, Any]) -> None: + sub = result.get("sub") + email = result.get("email") + if sub and email: + tracer.set_user_info(client_id=sub, email=email) + + # OIDC userinfo returns `role` as a list when there are + # multiple, or a bare string when there is exactly one. + # A substring check like `"Admin" in result.get("role", [])` + # matches `"Administrator"`, `"SuperAdmin"`, etc., so do an + # explicit equality check against a normalized list. + roles_raw = result.get("role", []) + if isinstance(roles_raw, str): + roles = [roles_raw] + elif isinstance(roles_raw, list): + roles = roles_raw + else: + roles = [] + + if "Admin" not in roles: + logger.debug("User is not an admin, not enabling beta features") + return + + logger.debug("User is an admin, enabling beta features") + rapidata_config.enableBetaFeatures = True + def _check_beta_features(self): """Enable beta features for the client.""" with optional_api_call("check beta features"): with tracer.start_as_current_span("RapidataClient.check_beta_features"): + cache_key: tuple[str, str] | None = ( + (self._environment, self._client_id) + if self._client_id is not None + else None + ) + if cache_key is not None: + with _userinfo_cache_lock: + entry = _userinfo_cache.get(cache_key) + if entry is not None and entry.expires_at > time.monotonic(): + cached_result = entry.result + else: + cached_result = None + if cached_result is not None: + logger.debug("Userinfo cache hit for %s", cache_key) + self._apply_userinfo(cached_result) + return + result: dict[str, Any] = json.loads( self._openapi_service.api_client.call_api( "GET", @@ -178,30 +247,19 @@ def _check_beta_features(self): ) logger.debug("Userinfo: %s", result) - client_id = result.get("sub") - email = result.get("email") - if client_id and email: - tracer.set_user_info(client_id=client_id, email=email) - - # OIDC userinfo returns `role` as a list when there are - # multiple, or a bare string when there is exactly one. - # A substring check like `"Admin" in result.get("role", [])` - # matches `"Administrator"`, `"SuperAdmin"`, etc., so do an - # explicit equality check against a normalized list. - roles_raw = result.get("role", []) - if isinstance(roles_raw, str): - roles = [roles_raw] - elif isinstance(roles_raw, list): - roles = roles_raw - else: - roles = [] - - if "Admin" not in roles: - logger.debug("User is not an admin, not enabling beta features") - return - - logger.debug("User is an admin, enabling beta features") - rapidata_config.enableBetaFeatures = True + effective_key = cache_key + if effective_key is None: + sub = result.get("sub") + if sub: + effective_key = (self._environment, sub) + if effective_key is not None: + with _userinfo_cache_lock: + _userinfo_cache[effective_key] = _UserInfoCacheEntry( + result=result, + expires_at=time.monotonic() + _USERINFO_CACHE_TTL_SECONDS, + ) + + self._apply_userinfo(result) def _check_version(self): with optional_api_call("version check"):