diff --git a/agentplatform/_genai/_operations_utils.py b/agentplatform/_genai/_operations_utils.py new file mode 100644 index 0000000000..12c54455aa --- /dev/null +++ b/agentplatform/_genai/_operations_utils.py @@ -0,0 +1,94 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Utility functions for Operations.""" + +import asyncio +import datetime +import time +from typing import Any, Awaitable, Callable + + +def await_operation( + *, + operation_name: str, + get_operation_fn: Callable[..., Any], + poll_interval: datetime.timedelta | float = 10.0, + timeout_seconds: float = 300.0, +) -> Any: + """Waits for a long running operation to complete. + + Args: + operation_name (str): Required. The name of the operation. + get_operation_fn (Callable): Required. Function to get the operation + status. + poll_interval (datetime.timedelta | float): The interval between polls. + timeout_seconds (float): The maximum wait duration in seconds. + + Returns: + Any: The completed operation. + """ + if isinstance(poll_interval, datetime.timedelta): + poll_seconds = poll_interval.total_seconds() + else: + poll_seconds = float(poll_interval) + + start_time = time.time() + operation = get_operation_fn(operation_name=operation_name) + while not operation.done: + if (time.time() - start_time) > timeout_seconds: + raise TimeoutError( + f"Operation {operation_name} did not complete within the timeout " + f"of {timeout_seconds} seconds." + ) + time.sleep(poll_seconds) + operation = get_operation_fn(operation_name=operation.name) + return operation + + +async def await_operation_async( + *, + operation_name: str, + get_operation_fn: Callable[..., Awaitable[Any]], + poll_interval: datetime.timedelta | float = 10.0, + timeout_seconds: float = 300.0, +) -> Any: + """Waits for a long running operation to complete asynchronously. + + Args: + operation_name (str): Required. The name of the operation. + get_operation_fn (Callable): Required. Async function to get the operation + status. + poll_interval (datetime.timedelta | float): The interval between polls. + timeout_seconds (float): The maximum wait duration in seconds. + + Returns: + Any: The completed operation. + """ + if isinstance(poll_interval, datetime.timedelta): + poll_seconds = poll_interval.total_seconds() + else: + poll_seconds = float(poll_interval) + + start_time = time.time() + operation = await get_operation_fn(operation_name=operation_name) + while not operation.done: + if (time.time() - start_time) > timeout_seconds: + raise TimeoutError( + f"Operation {operation_name} did not complete within the timeout " + f"of {timeout_seconds} seconds." + ) + await asyncio.sleep(poll_seconds) + operation = await get_operation_fn(operation_name=operation.name) + return operation diff --git a/agentplatform/_genai/_skills_utils.py b/agentplatform/_genai/_skills_utils.py index d46a3e0ea9..0258c9fcc7 100644 --- a/agentplatform/_genai/_skills_utils.py +++ b/agentplatform/_genai/_skills_utils.py @@ -14,14 +14,10 @@ # """Utility functions for Skills.""" -import asyncio import base64 -import datetime import io import os import pathlib -import time -from typing import Any, Awaitable, Callable import zipfile @@ -71,77 +67,3 @@ def get_zipped_filesystem_payload(directory_path: pathlib.Path | str) -> str: """ zip_bytes = zip_directory(directory_path) return base64.b64encode(zip_bytes).decode("utf-8") - - -def await_operation( - *, - operation_name: str, - get_operation_fn: Callable[..., Any], - poll_interval: datetime.timedelta | float = 10.0, - timeout_seconds: float = 300.0, -) -> Any: - """Waits for a long running operation to complete. - - Args: - operation_name (str): Required. The name of the operation. - get_operation_fn (Callable): Required. Function to get the operation - status. - poll_interval (datetime.timedelta | float): The interval between polls. - timeout_seconds (float): The maximum wait duration in seconds. - - Returns: - Any: The completed operation. - """ - if isinstance(poll_interval, datetime.timedelta): - poll_seconds = poll_interval.total_seconds() - else: - poll_seconds = float(poll_interval) - - start_time = time.time() - operation = get_operation_fn(operation_name=operation_name) - while not operation.done: - if (time.time() - start_time) > timeout_seconds: - raise TimeoutError( - f"Operation {operation_name} did not complete within the timeout " - f"of {timeout_seconds} seconds." - ) - time.sleep(poll_seconds) - operation = get_operation_fn(operation_name=operation.name) - return operation - - -async def await_operation_async( - *, - operation_name: str, - get_operation_fn: Callable[..., Awaitable[Any]], - poll_interval: datetime.timedelta | float = 10.0, - timeout_seconds: float = 300.0, -) -> Any: - """Waits for a long running operation to complete asynchronously. - - Args: - operation_name (str): Required. The name of the operation. - get_operation_fn (Callable): Required. Async function to get the operation - status. - poll_interval (datetime.timedelta | float): The interval between polls. - timeout_seconds (float): The maximum wait duration in seconds. - - Returns: - Any: The completed operation. - """ - if isinstance(poll_interval, datetime.timedelta): - poll_seconds = poll_interval.total_seconds() - else: - poll_seconds = float(poll_interval) - - start_time = time.time() - operation = await get_operation_fn(operation_name=operation_name) - while not operation.done: - if (time.time() - start_time) > timeout_seconds: - raise TimeoutError( - f"Operation {operation_name} did not complete within the timeout " - f"of {timeout_seconds} seconds." - ) - await asyncio.sleep(poll_seconds) - operation = await get_operation_fn(operation_name=operation.name) - return operation diff --git a/agentplatform/_genai/rag.py b/agentplatform/_genai/rag.py index c80642279d..0f43dd54c5 100644 --- a/agentplatform/_genai/rag.py +++ b/agentplatform/_genai/rag.py @@ -26,6 +26,7 @@ from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv +from . import _operations_utils from . import types logger = logging.getLogger("agentplatform_genai.rag") @@ -48,6 +49,33 @@ def _AskContextsRequestParameters_to_vertex( return to_object +def _CorpusOperation_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["done"]) is not None: + setv(to_object, ["done"], getv(from_object, ["done"])) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["response"]) is not None: + setv( + to_object, + ["response"], + _RagCorpus_from_vertex(getv(from_object, ["response"]), to_object), + ) + + return to_object + + def _CreateRagCorpusRequestParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -97,6 +125,19 @@ def _DeleteRagFileRequestParameters_to_vertex( return to_object +def _GetCorpusOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + def _GetRagConfigRequestParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -565,6 +606,77 @@ def _create_corpus( self._api_client._verify_response(return_value) return return_value + def _get_corpus_operation( + self, + *, + operation_name: str, + config: Optional[types.GetCorpusOperationConfigOrDict] = None, + ) -> types.CorpusOperation: + parameter_model = types._GetCorpusOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _GetCorpusOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CorpusOperation_from_vertex(response_dict) + + return_value = types.CorpusOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + def get_corpus( self, *, name: str, config: Optional[types.GetRagCorpusConfigOrDict] = None ) -> types.RagCorpus: @@ -1293,6 +1405,34 @@ def retrieve_contexts( self._api_client._verify_response(return_value) return return_value + def create_corpus( + self, + *, + rag_corpus: types.RagCorpusOrDict, + config: Optional[types.CreateRagCorpusConfigOrDict] = None, + ) -> types.RagCorpus: + """ + Creates a new Rag Corpus and waits for completion. + + Args: + rag_corpus: The RagCorpus to create. + config: The configuration to use for the RagCorpus. + + Returns: + The created RagCorpus. + """ + operation = self._create_corpus(rag_corpus=rag_corpus, config=config) + + operation = _operations_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_corpus_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to create RagCorpus: {operation.error}") + + return self.get_corpus(name=operation.response.name) + class AsyncRag(_api_module.BaseModule): @@ -1446,6 +1586,79 @@ async def _create_corpus( self._api_client._verify_response(return_value) return return_value + async def _get_corpus_operation( + self, + *, + operation_name: str, + config: Optional[types.GetCorpusOperationConfigOrDict] = None, + ) -> types.CorpusOperation: + parameter_model = types._GetCorpusOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _GetCorpusOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CorpusOperation_from_vertex(response_dict) + + return_value = types.CorpusOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + async def get_corpus( self, *, name: str, config: Optional[types.GetRagCorpusConfigOrDict] = None ) -> types.RagCorpus: @@ -2193,3 +2406,31 @@ async def retrieve_contexts( self._api_client._verify_response(return_value) return return_value + + async def create_corpus( + self, + *, + rag_corpus: types.RagCorpusOrDict, + config: Optional[types.CreateRagCorpusConfigOrDict] = None, + ) -> types.RagCorpus: + """ + Creates a new Rag Corpus and waits for completion asynchronously. + + Args: + rag_corpus: The RagCorpus to create. + config: The configuration to use for the RagCorpus. + + Returns: + The created RagCorpus. + """ + operation = await self._create_corpus(rag_corpus=rag_corpus, config=config) + + operation = await _operations_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_corpus_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to create RagCorpus: {operation.error}") + + return await self.get_corpus(name=operation.response.name) diff --git a/agentplatform/_genai/skills.py b/agentplatform/_genai/skills.py index f627756661..4f915a4d0d 100644 --- a/agentplatform/_genai/skills.py +++ b/agentplatform/_genai/skills.py @@ -30,6 +30,7 @@ from google.genai._common import set_value_by_path as setv from google.genai.pagers import AsyncPager, Pager +from . import _operations_utils from . import _skills_utils from . import types @@ -784,7 +785,7 @@ def create( ) if config.wait_for_completion: - operation = _skills_utils.await_operation( + operation = _operations_utils.await_operation( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) @@ -878,7 +879,7 @@ def update( ) if config.wait_for_completion: - operation = _skills_utils.await_operation( + operation = _operations_utils.await_operation( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) @@ -920,7 +921,7 @@ def delete( operation = self._delete(name=name, config=config) if config.wait_for_completion: - operation = _skills_utils.await_operation( + operation = _operations_utils.await_operation( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) @@ -1539,7 +1540,7 @@ async def create( ) if config.wait_for_completion: - operation = await _skills_utils.await_operation_async( + operation = await _operations_utils.await_operation_async( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) @@ -1634,7 +1635,7 @@ async def update( ) if config.wait_for_completion: - operation = await _skills_utils.await_operation_async( + operation = await _operations_utils.await_operation_async( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) @@ -1676,7 +1677,7 @@ async def delete( operation = await self._delete(name=name, config=config) if config.wait_for_completion: - operation = await _skills_utils.await_operation_async( + operation = await _operations_utils.await_operation_async( operation_name=operation.name, get_operation_fn=self._get_skill_operation, ) diff --git a/agentplatform/_genai/types/__init__.py b/agentplatform/_genai/types/__init__.py index ae9934cb49..036531599c 100644 --- a/agentplatform/_genai/types/__init__.py +++ b/agentplatform/_genai/types/__init__.py @@ -81,6 +81,7 @@ from .common import _GetAgentEngineSessionOperationParameters from .common import _GetAgentEngineSessionRequestParameters from .common import _GetAgentEngineTaskRequestParameters +from .common import _GetCorpusOperationParameters from .common import _GetCustomJobParameters from .common import _GetCustomJobParameters from .common import _GetDatasetOperationParameters @@ -292,6 +293,9 @@ from .common import ContentMapContentsOrDict from .common import ContentMapDict from .common import ContentMapOrDict +from .common import CorpusOperation +from .common import CorpusOperationDict +from .common import CorpusOperationOrDict from .common import CorpusStatus from .common import CorpusStatusDict from .common import CorpusStatusOrDict @@ -659,6 +663,9 @@ from .common import GetAgentEngineTaskConfig from .common import GetAgentEngineTaskConfigDict from .common import GetAgentEngineTaskConfigOrDict +from .common import GetCorpusOperationConfig +from .common import GetCorpusOperationConfigDict +from .common import GetCorpusOperationConfigOrDict from .common import GetDatasetOperationConfig from .common import GetDatasetOperationConfigDict from .common import GetDatasetOperationConfigOrDict @@ -2598,6 +2605,12 @@ "CreateRagCorpusOperation", "CreateRagCorpusOperationDict", "CreateRagCorpusOperationOrDict", + "GetCorpusOperationConfig", + "GetCorpusOperationConfigDict", + "GetCorpusOperationConfigOrDict", + "CorpusOperation", + "CorpusOperationDict", + "CorpusOperationOrDict", "GetRagCorpusConfig", "GetRagCorpusConfigDict", "GetRagCorpusConfigOrDict", @@ -3310,6 +3323,7 @@ "_ListAgentEngineMemoryRevisionsRequestParameters", "_AskContextsRequestParameters", "_CreateRagCorpusRequestParameters", + "_GetCorpusOperationParameters", "_GetRagCorpusRequestParameters", "_ListRagCorporaRequestParameters", "_GetRagFileRequestParameters", diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index 235ec387bb..347b269a16 100644 --- a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -13185,6 +13185,98 @@ class CreateRagCorpusOperationDict(TypedDict, total=False): ] +class GetCorpusOperationConfig(_common.BaseModel): + """Config for getting a corpus operation.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetCorpusOperationConfigDict(TypedDict, total=False): + """Config for getting a corpus operation.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetCorpusOperationConfigOrDict = Union[ + GetCorpusOperationConfig, GetCorpusOperationConfigDict +] + + +class _GetCorpusOperationParameters(_common.BaseModel): + """Parameters for getting a corpus operation.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetCorpusOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetCorpusOperationParametersDict(TypedDict, total=False): + """Parameters for getting a corpus operation.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetCorpusOperationConfigDict] + """Used to override the default configuration.""" + + +_GetCorpusOperationParametersOrDict = Union[ + _GetCorpusOperationParameters, _GetCorpusOperationParametersDict +] + + +class CorpusOperation(_common.BaseModel): + """Operation that has a corpus as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[RagCorpus] = Field( + default=None, description="""The created Corpus.""" + ) + + +class CorpusOperationDict(TypedDict, total=False): + """Operation that has a corpus as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[RagCorpusDict] + """The created Corpus.""" + + +CorpusOperationOrDict = Union[CorpusOperation, CorpusOperationDict] + + class GetRagCorpusConfig(_common.BaseModel): """Config for getting a RAG corpus.""" diff --git a/tests/unit/agentplatform/genai/replays/test_rag_create_corpus.py b/tests/unit/agentplatform/genai/replays/test_rag_create_corpus.py index 8bd4d4e30d..f0f1e84649 100644 --- a/tests/unit/agentplatform/genai/replays/test_rag_create_corpus.py +++ b/tests/unit/agentplatform/genai/replays/test_rag_create_corpus.py @@ -38,6 +38,21 @@ def test_create_rag_corpus_private(client): assert isinstance(corpus_op, types.CreateRagCorpusOperation) +def test_create_rag_corpus(client): + + corpus_description = "My Test Corpus Description" + + corpus = client.rag.create_corpus( + rag_corpus=types.RagCorpus( + display_name="My Test Corpus", + description=corpus_description, + ), + ) + + assert isinstance(corpus, types.RagCorpus) + assert corpus.description == corpus_description + + pytest_plugins = ("pytest_asyncio",) @@ -52,3 +67,19 @@ async def test_create_rag_corpus_private_async(client): ) assert isinstance(corpus_op, types.CreateRagCorpusOperation) + + +@pytest.mark.asyncio +async def test_create_rag_corpus_async(client): + + corpus_description = "My Test Corpus Description" + + corpus = await client.aio.rag.create_corpus( + rag_corpus=types.RagCorpus( + display_name="My Test Corpus", + description=corpus_description, + ), + ) + + assert isinstance(corpus, types.RagCorpus) + assert corpus.description == corpus_description