diff --git a/src/rapidata/__init__.py b/src/rapidata/__init__.py index 5325fbce..70d22023 100644 --- a/src/rapidata/__init__.py +++ b/src/rapidata/__init__.py @@ -2,6 +2,7 @@ from .rapidata_client import ( RapidataClient, + OpenAPIService, RapidataAudience, RapidataAudienceManager, RapidataOrderManager, diff --git a/src/rapidata/rapidata_client/__init__.py b/src/rapidata/rapidata_client/__init__.py index c66becf9..f3595582 100644 --- a/src/rapidata/rapidata_client/__init__.py +++ b/src/rapidata/rapidata_client/__init__.py @@ -1,4 +1,5 @@ from .rapidata_client import RapidataClient +from rapidata.service.openapi_service import OpenAPIService from .audience import ( RapidataAudience, RapidataAudienceManager, diff --git a/src/rapidata/rapidata_client/rapidata_client.py b/src/rapidata/rapidata_client/rapidata_client.py index 407a2b58..32ab42c3 100644 --- a/src/rapidata/rapidata_client/rapidata_client.py +++ b/src/rapidata/rapidata_client/rapidata_client.py @@ -44,6 +44,7 @@ def __init__( cert_path: str | None = None, token: dict | None = None, leeway: int = 60, + openapi_service: OpenAPIService | None = None, ): """Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json". If this is not successful, it will open a browser window and ask you to log in, then save your new credentials in said json file. @@ -56,6 +57,7 @@ def __init__( cert_path (str, optional): An optional path to a certificate file useful for development. token (dict, optional): If you already have a token that the client should use for authentication. Important, if set, this needs to be the complete token object containing the access token, token type and expiration time. leeway (int, optional): An optional leeway to use to determine if a token is expired. Defaults to 60 seconds. + openapi_service (OpenAPIService, optional): An existing OpenAPIService instance to reuse. When provided, the client will share the underlying connection pool and authentication session instead of creating new ones. This is useful when creating multiple RapidataClient instances to avoid opening too many connections. Attributes: order (RapidataOrderManager): The RapidataOrderManager instance. @@ -72,19 +74,26 @@ def __init__( with tracer.start_as_current_span("RapidataClient.__init__"): logger.debug("Checking version") self._check_version() - if environment != "rapidata.ai": - rapidata_config.logging.enable_otlp = False - - logger.debug("Initializing OpenAPIService") - self._openapi_service = OpenAPIService( - client_id=client_id, - client_secret=client_secret, - environment=environment, - oauth_scope=oauth_scope, - cert_path=cert_path, - token=token, - leeway=leeway, - ) + + if openapi_service is not None: + logger.debug("Reusing existing OpenAPIService") + self._openapi_service = openapi_service + if openapi_service.environment != "rapidata.ai": + rapidata_config.logging.enable_otlp = False + else: + if environment != "rapidata.ai": + rapidata_config.logging.enable_otlp = False + + logger.debug("Initializing OpenAPIService") + self._openapi_service = OpenAPIService( + client_id=client_id, + client_secret=client_secret, + environment=environment, + oauth_scope=oauth_scope, + cert_path=cert_path, + token=token, + leeway=leeway, + ) self._asset_uploader = AssetUploader(openapi_service=self._openapi_service) @@ -117,6 +126,11 @@ def __init__( self._check_beta_features() # can't be in the trace for some reason + @property + def openapi_service(self) -> OpenAPIService: + """The OpenAPIService instance used by this client. Can be passed to other RapidataClient instances to share the connection pool.""" + return self._openapi_service + def reset_credentials(self): """Reset the credentials saved in the configuration file for the current environment.""" logger.info("Resetting credentials")