Skip to content

Commit 26d24b6

Browse files
cleop-google authored and copybara-github committed
feat: Add Vertex Dataset input and output options for batch jobs
PiperOrigin-RevId: 898998803
1 parent b91bda5 commit 26d24b6

5 files changed

Lines changed: 276 additions & 5 deletions

File tree

google/genai/_transformers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,11 @@ def t_batch_job_source(
10121012
src = types.BatchJobSource(**src)
10131013
if is_duck_type_of(src, types.BatchJobSource):
10141014
vertex_sources = sum(
1015-
[src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr]
1015+
[
1016+
src.gcs_uri is not None,
1017+
src.bigquery_uri is not None,
1018+
src.vertex_dataset_name is not None,
1019+
] # type: ignore[union-attr]
10161020
)
10171021
mldev_sources = sum([
10181022
src.inlined_requests is not None, # type: ignore[union-attr]
@@ -1021,7 +1025,7 @@ def t_batch_job_source(
10211025
if client.vertexai:
10221026
if mldev_sources or vertex_sources != 1:
10231027
raise ValueError(
1024-
'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
1028+
'Exactly one of `gcs_uri`, `bigquery_uri`, or `vertex_dataset_name` must be set, other '
10251029
'sources are not supported in Vertex AI.'
10261030
)
10271031
else:
@@ -1046,6 +1050,11 @@ def t_batch_job_source(
10461050
format='bigquery',
10471051
bigquery_uri=src,
10481052
)
1053+
elif re.match(r'^projects/[^/]+/locations/[^/]+/datasets/[^/]+$', src):
1054+
return types.BatchJobSource(
1055+
format='vertex-dataset',
1056+
vertex_dataset_name=src,
1057+
)
10491058
elif src.startswith('files/'):
10501059
return types.BatchJobSource(
10511060
file_name=src,

google/genai/batches.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def _BatchJobDestination_from_vertex(
130130
getv(from_object, ['bigqueryDestination', 'outputUri']),
131131
)
132132

133+
if getv(from_object, ['vertexMultimodalDatasetDestination']) is not None:
134+
setv(
135+
to_object,
136+
['vertex_dataset'],
137+
_VertexMultimodalDatasetDestination_from_vertex(
138+
getv(from_object, ['vertexMultimodalDatasetDestination']), to_object
139+
),
140+
)
141+
133142
return to_object
134143

135144

@@ -169,6 +178,15 @@ def _BatchJobDestination_to_vertex(
169178
' Vertex AI.'
170179
)
171180

181+
if getv(from_object, ['vertex_dataset']) is not None:
182+
setv(
183+
to_object,
184+
['vertexMultimodalDatasetDestination'],
185+
_VertexMultimodalDatasetDestination_to_vertex(
186+
getv(from_object, ['vertex_dataset']), to_object
187+
),
188+
)
189+
172190
return to_object
173191

174192

@@ -190,6 +208,16 @@ def _BatchJobSource_from_vertex(
190208
getv(from_object, ['bigquerySource', 'inputUri']),
191209
)
192210

211+
if (
212+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName'])
213+
is not None
214+
):
215+
setv(
216+
to_object,
217+
['vertex_dataset_name'],
218+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName']),
219+
)
220+
193221
return to_object
194222

195223

@@ -221,6 +249,11 @@ def _BatchJobSource_to_mldev(
221249
],
222250
)
223251

252+
if getv(from_object, ['vertex_dataset_name']) is not None:
253+
raise ValueError(
254+
'vertex_dataset_name parameter is not supported in Gemini API.'
255+
)
256+
224257
return to_object
225258

226259

@@ -250,6 +283,13 @@ def _BatchJobSource_to_vertex(
250283
'inlined_requests parameter is not supported in Vertex AI.'
251284
)
252285

286+
if getv(from_object, ['vertex_dataset_name']) is not None:
287+
setv(
288+
to_object,
289+
['vertexMultimodalDatasetSource', 'datasetName'],
290+
getv(from_object, ['vertex_dataset_name']),
291+
)
292+
253293
return to_object
254294

255295

@@ -1593,6 +1633,42 @@ def _Tool_to_mldev(
15931633
return to_object
15941634

15951635

1636+
def _VertexMultimodalDatasetDestination_from_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
  """Converts a Vertex API multimodal-dataset destination to SDK field names.

  Maps the camelCase Vertex response fields (``bigqueryDestination.outputUri``,
  ``displayName``) onto the snake_case SDK representation. Fields absent from
  the response are simply omitted from the result.
  """
  result: dict[str, Any] = {}

  bq_output_uri = getv(from_object, ['bigqueryDestination', 'outputUri'])
  if bq_output_uri is not None:
    setv(result, ['bigquery_destination'], bq_output_uri)

  display_name = getv(from_object, ['displayName'])
  if display_name is not None:
    setv(result, ['display_name'], display_name)

  return result
1652+
1653+
1654+
def _VertexMultimodalDatasetDestination_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
  """Converts an SDK multimodal-dataset destination to Vertex API field names.

  Inverse of ``_VertexMultimodalDatasetDestination_from_vertex``: snake_case
  SDK fields (``bigquery_destination``, ``display_name``) become the camelCase
  request fields expected by the Vertex API. Unset fields are omitted.
  """
  result: dict[str, Any] = {}

  bq_destination = getv(from_object, ['bigquery_destination'])
  if bq_destination is not None:
    setv(result, ['bigqueryDestination', 'outputUri'], bq_destination)

  display_name = getv(from_object, ['display_name'])
  if display_name is not None:
    setv(result, ['displayName'], display_name)

  return result
1670+
1671+
15961672
class Batches(_api_module.BaseModule):
15971673

15981674
def _create(
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
"""Tests for batches.create() with Vertex dataset source."""
18+
19+
import re
20+
21+
import pytest
22+
23+
from .. import pytest_helper
24+
from ... import types
25+
26+
27+
_GEMINI_MODEL = 'gemini-2.5-flash'
28+
_GEMINI_MODEL_FULL_NAME = 'publishers/google/models/gemini-2.5-flash'
29+
_OUTPUT_VERTEX_DATASET_DISPLAY_NAME = 'test_batch_output'
30+
_VERTEX_DATASET_INPUT_NAME = (
31+
'projects/vertex-sdk-dev/locations/us-central1/datasets/7857316250517504000'
32+
)
33+
34+
_BQ_OUTPUT_PREFIX = (
35+
'bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_output'
36+
)
37+
_VERTEX_DATASET_DESTINATION = types.VertexMultimodalDatasetDestination(
38+
bigquery_destination=_BQ_OUTPUT_PREFIX,
39+
display_name=_OUTPUT_VERTEX_DATASET_DISPLAY_NAME,
40+
)
41+
42+
43+
# All tests will be run for both Vertex and MLDev.
44+
test_table: list[pytest_helper.TestTableItem] = [
45+
pytest_helper.TestTableItem(
46+
name='test_union_generate_content_with_vertex_dataset_name',
47+
parameters=types._CreateBatchJobParameters(
48+
model=_GEMINI_MODEL_FULL_NAME,
49+
src=_VERTEX_DATASET_INPUT_NAME,
50+
config={
51+
'dest': {
52+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
53+
'format': 'vertex-dataset',
54+
},
55+
},
56+
),
57+
exception_if_mldev='not supported in Gemini API',
58+
has_union=True,
59+
),
60+
pytest_helper.TestTableItem(
61+
name='test_generate_content_with_vertex_dataset_source',
62+
parameters=types._CreateBatchJobParameters(
63+
model=_GEMINI_MODEL_FULL_NAME,
64+
src=types.BatchJobSource(
65+
vertex_dataset_name=_VERTEX_DATASET_INPUT_NAME,
66+
format='vertex-dataset',
67+
),
68+
config={
69+
'dest': {
70+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
71+
'format': 'vertex-dataset',
72+
},
73+
},
74+
),
75+
exception_if_mldev='one of',
76+
),
77+
pytest_helper.TestTableItem(
78+
name='test_generate_content_with_vertex_dataset_source_dict',
79+
parameters=types._CreateBatchJobParameters(
80+
model=_GEMINI_MODEL_FULL_NAME,
81+
src={
82+
'vertex_dataset_name': _VERTEX_DATASET_INPUT_NAME,
83+
'format': 'vertex-dataset',
84+
},
85+
config={
86+
'dest': {
87+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
88+
'format': 'vertex-dataset',
89+
},
90+
},
91+
),
92+
exception_if_mldev='one of',
93+
),
94+
]
95+
96+
pytestmark = [
97+
pytest.mark.usefixtures('mock_timestamped_unique_name'),
98+
pytest_helper.setup(
99+
file=__file__,
100+
globals_for_file=globals(),
101+
test_method='batches.create',
102+
test_table=test_table,
103+
),
104+
]
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_async_create(client):
109+
with pytest_helper.exception_if_mldev(client, ValueError):
110+
batch_job = await client.aio.batches.create(
111+
model=_GEMINI_MODEL,
112+
src=_VERTEX_DATASET_INPUT_NAME,
113+
config={
114+
'dest': {
115+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
116+
'format': 'vertex-dataset',
117+
},
118+
},
119+
)
120+
121+
assert batch_job.name.startswith('projects/')
122+
assert (
123+
batch_job.model == _GEMINI_MODEL_FULL_NAME
124+
) # Converted to Vertex full name.
125+
assert batch_job.src.vertex_dataset_name == _VERTEX_DATASET_INPUT_NAME
126+
assert batch_job.src.format == 'vertex-dataset'

google/genai/tests/transformers/test_t_batch.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,19 +172,36 @@ def test_batch_job_source_vertexai_valid_bigquery(self, vertex_client):
172172
result = t.t_batch_job_source(vertex_client, src_obj)
173173
assert result is src_obj
174174

175-
def test_batch_job_source_vertexai_valid_both(self, vertex_client):
175+
def test_batch_job_source_vertexai_valid_all(self, vertex_client):
  """Setting all three Vertex sources at once must be rejected."""
  source = types.BatchJobSource(
      gcs_uri=['gs://vertex-bucket/data.jsonl'],
      bigquery_uri='bq://project.dataset.table',
      vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
  )
  with pytest.raises(
      ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'
  ):
    t.t_batch_job_source(vertex_client, source)
183+
184+
def test_batch_job_source_vertexai_valid_gcs_and_bigquery(self, vertex_client):
  """GCS and BigQuery sources together are still "more than one" — rejected."""
  source = types.BatchJobSource(
      gcs_uri=['gs://vertex-bucket/data.jsonl'],
      bigquery_uri='bq://project.dataset.table',
  )
  with pytest.raises(
      ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'
  ):
    t.t_batch_job_source(vertex_client, source)
191+
192+
def test_batch_job_source_vertexai_valid_bigquery_and_vertex_dataset(
    self, vertex_client
):
  """BigQuery plus a Vertex dataset name is ambiguous — rejected."""
  source = types.BatchJobSource(
      bigquery_uri='bq://project.dataset.table',
      vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
  )
  with pytest.raises(
      ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'
  ):
    t.t_batch_job_source(vertex_client, source)
182199

183200
def test_batch_job_source_vertexai_invalid_neither_set(self, vertex_client):
184201
src_obj = types.BatchJobSource(
185202
file_name='files/data.csv'
186203
)
187-
with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'):
204+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
188205
t.t_batch_job_source(vertex_client, src_obj)
189206

190207

google/genai/types.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16099,6 +16099,11 @@ class BatchJobSource(_common.BaseModel):
1609916099
description="""The Gemini Developer API's inlined input data to run batch job.
1610016100
""",
1610116101
)
16102+
vertex_dataset_name: Optional[str] = Field(
16103+
default=None,
16104+
description="""The Vertex AI dataset resource name to use as input. Must be of type multimodal.
16105+
""",
16106+
)
1610216107

1610316108

1610416109
class BatchJobSourceDict(TypedDict, total=False):
@@ -16126,10 +16131,42 @@ class BatchJobSourceDict(TypedDict, total=False):
1612616131
"""The Gemini Developer API's inlined input data to run batch job.
1612716132
"""
1612816133

16134+
vertex_dataset_name: Optional[str]
16135+
"""The Vertex AI dataset resource name to use as input. Must be of type multimodal.
16136+
"""
16137+
1612916138

1613016139
BatchJobSourceOrDict = Union[BatchJobSource, BatchJobSourceDict]
1613116140

1613216141

16142+
class VertexMultimodalDatasetDestination(_common.BaseModel):
  """Destination describing a Vertex AI multimodal dataset for batch output.

  Used as the `vertex_dataset` field of a batch job destination. Carries the
  BigQuery table backing the dataset and its display name.
  NOTE(review): exact server-side semantics of each field should be confirmed
  against the Vertex AI batch prediction API reference.
  """

  bigquery_destination: Optional[str] = Field(
      default=None,
      description="""The BigQuery destination for the multimodal dataset.""",
  )
  display_name: Optional[str] = Field(
      default=None,
      description="""The display name of the multimodal dataset.""",
  )


class VertexMultimodalDatasetDestinationDict(TypedDict, total=False):
  """Dict form of `VertexMultimodalDatasetDestination`."""

  bigquery_destination: Optional[str]
  """The BigQuery destination for the multimodal dataset."""

  display_name: Optional[str]
  """The display name of the multimodal dataset."""


VertexMultimodalDatasetDestinationOrDict = Union[
    VertexMultimodalDatasetDestination, VertexMultimodalDatasetDestinationDict
]
16168+
16169+
1613316170
class JobError(_common.BaseModel):
1613416171
"""Job error."""
1613516172

@@ -16307,6 +16344,9 @@ class BatchJobDestination(_common.BaseModel):
1630716344
the input requests.
1630816345
""",
1630916346
)
16347+
vertex_dataset: Optional[VertexMultimodalDatasetDestination] = Field(
16348+
default=None, description="""TODO"""
16349+
)
1631016350

1631116351

1631216352
class BatchJobDestinationDict(TypedDict, total=False):
@@ -16347,6 +16387,9 @@ class BatchJobDestinationDict(TypedDict, total=False):
1634716387
the input requests.
1634816388
"""
1634916389

16390+
vertex_dataset: Optional[VertexMultimodalDatasetDestinationDict]
16391+
"""TODO"""
16392+
1635016393

1635116394
BatchJobDestinationOrDict = Union[BatchJobDestination, BatchJobDestinationDict]
1635216395

0 commit comments

Comments
 (0)