From 88b9436b0adff483375c9f8be2e6ed564964613a Mon Sep 17 00:00:00 2001 From: Christian Leopoldseder Date: Mon, 13 Apr 2026 07:53:52 -0700 Subject: [PATCH] feat: Add Vertex Dataset input and output options for batch jobs PiperOrigin-RevId: 898998803 --- google/genai/_transformers.py | 13 +- google/genai/batches.py | 76 ++++++++++ .../test_create_with_vertex_dataset.py | 130 ++++++++++++++++++ .../genai/tests/transformers/test_t_batch.py | 23 +++- google/genai/types.py | 52 +++++++ 5 files changed, 289 insertions(+), 5 deletions(-) create mode 100644 google/genai/tests/batches/test_create_with_vertex_dataset.py diff --git a/google/genai/_transformers.py b/google/genai/_transformers.py index b9867d3e6..27d0e628b 100644 --- a/google/genai/_transformers.py +++ b/google/genai/_transformers.py @@ -1012,7 +1012,11 @@ def t_batch_job_source( src = types.BatchJobSource(**src) if is_duck_type_of(src, types.BatchJobSource): vertex_sources = sum( - [src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr] + [ + src.gcs_uri is not None, # type: ignore[union-attr] + src.bigquery_uri is not None, # type: ignore[union-attr] + src.vertex_dataset_name is not None, # type: ignore[union-attr] + ] ) mldev_sources = sum([ src.inlined_requests is not None, # type: ignore[union-attr] @@ -1021,7 +1025,7 @@ def t_batch_job_source( if client.vertexai: if mldev_sources or vertex_sources != 1: raise ValueError( - 'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other ' + 'Exactly one of `gcs_uri`, `bigquery_uri`, or `vertex_dataset_name` must be set, other ' 'sources are not supported in Vertex AI.' 
) else: @@ -1046,6 +1050,11 @@ def t_batch_job_source( format='bigquery', bigquery_uri=src, ) + elif re.match(r'^projects/[^/]+/locations/[^/]+/datasets/[^/]+$', src): + return types.BatchJobSource( + format='vertex-dataset', + vertex_dataset_name=src, + ) elif src.startswith('files/'): return types.BatchJobSource( file_name=src, diff --git a/google/genai/batches.py b/google/genai/batches.py index 9c82e5b1d..7d122f086 100644 --- a/google/genai/batches.py +++ b/google/genai/batches.py @@ -130,6 +130,15 @@ def _BatchJobDestination_from_vertex( getv(from_object, ['bigqueryDestination', 'outputUri']), ) + if getv(from_object, ['vertexMultimodalDatasetDestination']) is not None: + setv( + to_object, + ['vertex_dataset'], + _VertexMultimodalDatasetDestination_from_vertex( + getv(from_object, ['vertexMultimodalDatasetDestination']), to_object + ), + ) + return to_object @@ -169,6 +178,15 @@ def _BatchJobDestination_to_vertex( ' Vertex AI.' ) + if getv(from_object, ['vertex_dataset']) is not None: + setv( + to_object, + ['vertexMultimodalDatasetDestination'], + _VertexMultimodalDatasetDestination_to_vertex( + getv(from_object, ['vertex_dataset']), to_object + ), + ) + return to_object @@ -190,6 +208,16 @@ def _BatchJobSource_from_vertex( getv(from_object, ['bigquerySource', 'inputUri']), ) + if ( + getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName']) + is not None + ): + setv( + to_object, + ['vertex_dataset_name'], + getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName']), + ) + return to_object @@ -221,6 +249,11 @@ def _BatchJobSource_to_mldev( ], ) + if getv(from_object, ['vertex_dataset_name']) is not None: + raise ValueError( + 'vertex_dataset_name parameter is not supported in Gemini API.' + ) + return to_object @@ -250,6 +283,13 @@ def _BatchJobSource_to_vertex( 'inlined_requests parameter is not supported in Vertex AI.' 
) + if getv(from_object, ['vertex_dataset_name']) is not None: + setv( + to_object, + ['vertexMultimodalDatasetSource', 'datasetName'], + getv(from_object, ['vertex_dataset_name']), + ) + return to_object @@ -1603,6 +1643,42 @@ def _Tool_to_mldev( return to_object +def _VertexMultimodalDatasetDestination_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['bigqueryDestination', 'outputUri']) is not None: + setv( + to_object, + ['bigquery_destination'], + getv(from_object, ['bigqueryDestination', 'outputUri']), + ) + + if getv(from_object, ['displayName']) is not None: + setv(to_object, ['display_name'], getv(from_object, ['displayName'])) + + return to_object + + +def _VertexMultimodalDatasetDestination_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['bigquery_destination']) is not None: + setv( + to_object, + ['bigqueryDestination', 'outputUri'], + getv(from_object, ['bigquery_destination']), + ) + + if getv(from_object, ['display_name']) is not None: + setv(to_object, ['displayName'], getv(from_object, ['display_name'])) + + return to_object + + class Batches(_api_module.BaseModule): def _create( diff --git a/google/genai/tests/batches/test_create_with_vertex_dataset.py b/google/genai/tests/batches/test_create_with_vertex_dataset.py new file mode 100644 index 000000000..63e274c63 --- /dev/null +++ b/google/genai/tests/batches/test_create_with_vertex_dataset.py @@ -0,0 +1,130 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +"""Tests for batches.create() with Vertex dataset source.""" + +import re + +import pytest + +from .. import pytest_helper +from ... import types + + +_GEMINI_MODEL = 'gemini-2.5-flash' +_GEMINI_MODEL_FULL_NAME = 'publishers/google/models/gemini-2.5-flash' +_OUTPUT_VERTEX_DATASET_DISPLAY_NAME = 'test_batch_output' +_VERTEX_DATASET_INPUT_NAME = ( + 'projects/vertex-sdk-dev/locations/us-central1/datasets/7857316250517504000' +) +_DISPLAY_NAME = 'test_batch' + +_BQ_OUTPUT_PREFIX = ( + 'bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_output' +) +_VERTEX_DATASET_DESTINATION = types.VertexMultimodalDatasetDestination( + bigquery_destination=_BQ_OUTPUT_PREFIX, + display_name=_OUTPUT_VERTEX_DATASET_DISPLAY_NAME, +) + + +# All tests will be run for both Vertex and MLDev. 
+test_table: list[pytest_helper.TestTableItem] = [ + pytest_helper.TestTableItem( + name='test_union_generate_content_with_vertex_dataset_name', + parameters=types._CreateBatchJobParameters( + model=_GEMINI_MODEL_FULL_NAME, + src=_VERTEX_DATASET_INPUT_NAME, + config={ + 'display_name': _DISPLAY_NAME, + 'dest': { + 'vertex_dataset': _VERTEX_DATASET_DESTINATION, + 'format': 'vertex-dataset', + }, + }, + ), + exception_if_mldev='not supported in Gemini API', + has_union=True, + ), + pytest_helper.TestTableItem( + name='test_generate_content_with_vertex_dataset_source', + parameters=types._CreateBatchJobParameters( + model=_GEMINI_MODEL_FULL_NAME, + src=types.BatchJobSource( + vertex_dataset_name=_VERTEX_DATASET_INPUT_NAME, + format='vertex-dataset', + ), + config={ + 'display_name': _DISPLAY_NAME, + 'dest': { + 'vertex_dataset': _VERTEX_DATASET_DESTINATION, + 'format': 'vertex-dataset', + }, + }, + ), + exception_if_mldev='one of', + ), + pytest_helper.TestTableItem( + name='test_generate_content_with_vertex_dataset_source_dict', + parameters=types._CreateBatchJobParameters( + model=_GEMINI_MODEL_FULL_NAME, + src={ + 'vertex_dataset_name': _VERTEX_DATASET_INPUT_NAME, + 'format': 'vertex-dataset', + }, + config={ + 'display_name': _DISPLAY_NAME, + 'dest': { + 'vertex_dataset': _VERTEX_DATASET_DESTINATION, + 'format': 'vertex-dataset', + }, + }, + ), + exception_if_mldev='one of', + ), +] + +pytestmark = [ + pytest.mark.usefixtures('mock_timestamped_unique_name'), + pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method='batches.create', + test_table=test_table, + ), +] + + +@pytest.mark.asyncio +async def test_async_create(client): + with pytest_helper.exception_if_mldev(client, ValueError): + batch_job = await client.aio.batches.create( + model=_GEMINI_MODEL, + src=_VERTEX_DATASET_INPUT_NAME, + config={ + 'dest': { + 'vertex_dataset': _VERTEX_DATASET_DESTINATION, + 'format': 'vertex-dataset', + }, + }, + ) + + assert 
batch_job.name.startswith('projects/') + assert ( + batch_job.model == _GEMINI_MODEL_FULL_NAME + ) # Converted to Vertex full name. + assert batch_job.src.vertex_dataset_name == _VERTEX_DATASET_INPUT_NAME + assert batch_job.src.format == 'vertex-dataset' diff --git a/google/genai/tests/transformers/test_t_batch.py b/google/genai/tests/transformers/test_t_batch.py index f355690ee..991f9b6a3 100644 --- a/google/genai/tests/transformers/test_t_batch.py +++ b/google/genai/tests/transformers/test_t_batch.py @@ -172,19 +172,36 @@ def test_batch_job_source_vertexai_valid_bigquery(self, vertex_client): result = t.t_batch_job_source(vertex_client, src_obj) assert result is src_obj - def test_batch_job_source_vertexai_valid_both(self, vertex_client): + def test_batch_job_source_vertexai_invalid_all(self, vertex_client): src_obj = types.BatchJobSource( gcs_uri=['gs://vertex-bucket/data.jsonl'], bigquery_uri='bq://project.dataset.table', + vertex_dataset_name='projects/123/locations/us-central1/datasets/456', ) - with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'): + with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'): + t.t_batch_job_source(vertex_client, src_obj) + + def test_batch_job_source_vertexai_invalid_gcs_and_bigquery(self, vertex_client): + src_obj = types.BatchJobSource( + gcs_uri=['gs://vertex-bucket/data.jsonl'], + bigquery_uri='bq://project.dataset.table', + ) + with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'): + t.t_batch_job_source(vertex_client, src_obj) + + def test_batch_job_source_vertexai_invalid_bigquery_and_vertex_dataset(self, vertex_client): + src_obj = types.BatchJobSource( + bigquery_uri='bq://project.dataset.table', + vertex_dataset_name='projects/123/locations/us-central1/datasets/456', + ) + with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'): t.t_batch_job_source(vertex_client, src_obj) def 
test_batch_job_source_vertexai_invalid_neither_set(self, vertex_client): src_obj = types.BatchJobSource( file_name='files/data.csv' ) - with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'): + with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'): t.t_batch_job_source(vertex_client, src_obj) diff --git a/google/genai/types.py b/google/genai/types.py index 20c4f1c0e..3542f9365 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -16148,6 +16148,11 @@ class BatchJobSource(_common.BaseModel): description="""The Gemini Developer API's inlined input data to run batch job. """, ) + vertex_dataset_name: Optional[str] = Field( + default=None, + description="""This field is experimental and may change in future versions. The Vertex AI dataset resource name to use as input. Must be of type multimodal. + """, + ) class BatchJobSourceDict(TypedDict, total=False): @@ -16175,10 +16180,48 @@ class BatchJobSourceDict(TypedDict, total=False): """The Gemini Developer API's inlined input data to run batch job. """ + vertex_dataset_name: Optional[str] + """This field is experimental and may change in future versions. The Vertex AI dataset resource name to use as input. Must be of type multimodal. + """ + BatchJobSourceOrDict = Union[BatchJobSource, BatchJobSourceDict] +class VertexMultimodalDatasetDestination(_common.BaseModel): + """This class is experimental and may change in future versions. + + The specification for an output Vertex AI multimodal dataset. + """ + + bigquery_destination: Optional[str] = Field( + default=None, + description="""The BigQuery destination for the multimodal dataset.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The display name of the multimodal dataset.""", + ) + + +class VertexMultimodalDatasetDestinationDict(TypedDict, total=False): + """This class is experimental and may change in future versions. 
+ + The specification for an output Vertex AI multimodal dataset. + """ + + bigquery_destination: Optional[str] + """The BigQuery destination for the multimodal dataset.""" + + display_name: Optional[str] + """The display name of the multimodal dataset.""" + + +VertexMultimodalDatasetDestinationOrDict = Union[ + VertexMultimodalDatasetDestination, VertexMultimodalDatasetDestinationDict +] + + class JobError(_common.BaseModel): """Job error.""" @@ -16356,6 +16399,11 @@ class BatchJobDestination(_common.BaseModel): the input requests. """, ) + vertex_dataset: Optional[VertexMultimodalDatasetDestination] = Field( + default=None, + description="""This field is experimental and may change in future versions. The Vertex AI dataset destination. + """, + ) class BatchJobDestinationDict(TypedDict, total=False): @@ -16396,6 +16444,10 @@ class BatchJobDestinationDict(TypedDict, total=False): the input requests. """ + vertex_dataset: Optional[VertexMultimodalDatasetDestinationDict] + """This field is experimental and may change in future versions. The Vertex AI dataset destination. + """ + BatchJobDestinationOrDict = Union[BatchJobDestination, BatchJobDestinationDict]