Skip to content

Commit 500d1a4

Browse files
cleop-google authored and copybara-github committed
feat: Add Vertex Dataset input and output options for batch jobs
PiperOrigin-RevId: 898998803
1 parent 3f36ca1 commit 500d1a4

5 files changed

Lines changed: 285 additions & 5 deletions

File tree

google/genai/_transformers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,11 @@ def t_batch_job_source(
10121012
src = types.BatchJobSource(**src)
10131013
if is_duck_type_of(src, types.BatchJobSource):
10141014
vertex_sources = sum(
1015-
[src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr]
1015+
[
1016+
src.gcs_uri is not None,
1017+
src.bigquery_uri is not None,
1018+
src.vertex_dataset_name is not None,
1019+
] # type: ignore[union-attr]
10161020
)
10171021
mldev_sources = sum([
10181022
src.inlined_requests is not None, # type: ignore[union-attr]
@@ -1021,7 +1025,7 @@ def t_batch_job_source(
10211025
if client.vertexai:
10221026
if mldev_sources or vertex_sources != 1:
10231027
raise ValueError(
1024-
'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
1028+
'Exactly one of `gcs_uri`, `bigquery_uri`, or `vertex_dataset_name` must be set, other '
10251029
'sources are not supported in Vertex AI.'
10261030
)
10271031
else:
@@ -1046,6 +1050,11 @@ def t_batch_job_source(
10461050
format='bigquery',
10471051
bigquery_uri=src,
10481052
)
1053+
elif re.match(r'^projects/[^/]+/locations/[^/]+/datasets/[^/]+$', src):
1054+
return types.BatchJobSource(
1055+
format='vertex-dataset',
1056+
vertex_dataset_name=src,
1057+
)
10491058
elif src.startswith('files/'):
10501059
return types.BatchJobSource(
10511060
file_name=src,

google/genai/batches.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def _BatchJobDestination_from_vertex(
130130
getv(from_object, ['bigqueryDestination', 'outputUri']),
131131
)
132132

133+
if getv(from_object, ['vertexMultimodalDatasetDestination']) is not None:
134+
setv(
135+
to_object,
136+
['vertex_dataset'],
137+
_VertexMultimodalDatasetDestination_from_vertex(
138+
getv(from_object, ['vertexMultimodalDatasetDestination']), to_object
139+
),
140+
)
141+
133142
return to_object
134143

135144

@@ -169,6 +178,15 @@ def _BatchJobDestination_to_vertex(
169178
' Vertex AI.'
170179
)
171180

181+
if getv(from_object, ['vertex_dataset']) is not None:
182+
setv(
183+
to_object,
184+
['vertexMultimodalDatasetDestination'],
185+
_VertexMultimodalDatasetDestination_to_vertex(
186+
getv(from_object, ['vertex_dataset']), to_object
187+
),
188+
)
189+
172190
return to_object
173191

174192

@@ -190,6 +208,16 @@ def _BatchJobSource_from_vertex(
190208
getv(from_object, ['bigquerySource', 'inputUri']),
191209
)
192210

211+
if (
212+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName'])
213+
is not None
214+
):
215+
setv(
216+
to_object,
217+
['vertex_dataset_name'],
218+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName']),
219+
)
220+
193221
return to_object
194222

195223

@@ -221,6 +249,11 @@ def _BatchJobSource_to_mldev(
221249
],
222250
)
223251

252+
if getv(from_object, ['vertex_dataset_name']) is not None:
253+
raise ValueError(
254+
'vertex_dataset_name parameter is not supported in Gemini API.'
255+
)
256+
224257
return to_object
225258

226259

@@ -250,6 +283,13 @@ def _BatchJobSource_to_vertex(
250283
'inlined_requests parameter is not supported in Vertex AI.'
251284
)
252285

286+
if getv(from_object, ['vertex_dataset_name']) is not None:
287+
setv(
288+
to_object,
289+
['vertexMultimodalDatasetSource', 'datasetName'],
290+
getv(from_object, ['vertex_dataset_name']),
291+
)
292+
253293
return to_object
254294

255295

@@ -1603,6 +1643,42 @@ def _Tool_to_mldev(
16031643
return to_object
16041644

16051645

1646+
def _VertexMultimodalDatasetDestination_from_vertex(
1647+
from_object: Union[dict[str, Any], object],
1648+
parent_object: Optional[dict[str, Any]] = None,
1649+
) -> dict[str, Any]:
1650+
to_object: dict[str, Any] = {}
1651+
if getv(from_object, ['bigqueryDestination', 'outputUri']) is not None:
1652+
setv(
1653+
to_object,
1654+
['bigquery_destination'],
1655+
getv(from_object, ['bigqueryDestination', 'outputUri']),
1656+
)
1657+
1658+
if getv(from_object, ['displayName']) is not None:
1659+
setv(to_object, ['display_name'], getv(from_object, ['displayName']))
1660+
1661+
return to_object
1662+
1663+
1664+
def _VertexMultimodalDatasetDestination_to_vertex(
1665+
from_object: Union[dict[str, Any], object],
1666+
parent_object: Optional[dict[str, Any]] = None,
1667+
) -> dict[str, Any]:
1668+
to_object: dict[str, Any] = {}
1669+
if getv(from_object, ['bigquery_destination']) is not None:
1670+
setv(
1671+
to_object,
1672+
['bigqueryDestination', 'outputUri'],
1673+
getv(from_object, ['bigquery_destination']),
1674+
)
1675+
1676+
if getv(from_object, ['display_name']) is not None:
1677+
setv(to_object, ['displayName'], getv(from_object, ['display_name']))
1678+
1679+
return to_object
1680+
1681+
16061682
class Batches(_api_module.BaseModule):
16071683

16081684
def _create(
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
"""Tests for batches.create() with Vertex dataset source."""
18+
19+
import re
20+
21+
import pytest
22+
23+
from .. import pytest_helper
24+
from ... import types
25+
26+
27+
_GEMINI_MODEL = 'gemini-2.5-flash'
28+
_GEMINI_MODEL_FULL_NAME = 'publishers/google/models/gemini-2.5-flash'
29+
_OUTPUT_VERTEX_DATASET_DISPLAY_NAME = 'test_batch_output'
30+
_VERTEX_DATASET_INPUT_NAME = (
31+
'projects/vertex-sdk-dev/locations/us-central1/datasets/7857316250517504000'
32+
)
33+
34+
_BQ_OUTPUT_PREFIX = (
35+
'bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_output'
36+
)
37+
_VERTEX_DATASET_DESTINATION = types.VertexMultimodalDatasetDestination(
38+
bigquery_destination=_BQ_OUTPUT_PREFIX,
39+
display_name=_OUTPUT_VERTEX_DATASET_DISPLAY_NAME,
40+
)
41+
42+
43+
# All tests will be run for both Vertex and MLDev.
44+
test_table: list[pytest_helper.TestTableItem] = [
45+
pytest_helper.TestTableItem(
46+
name='test_union_generate_content_with_vertex_dataset_name',
47+
parameters=types._CreateBatchJobParameters(
48+
model=_GEMINI_MODEL_FULL_NAME,
49+
src=_VERTEX_DATASET_INPUT_NAME,
50+
config={
51+
'dest': {
52+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
53+
'format': 'vertex-dataset',
54+
},
55+
},
56+
),
57+
exception_if_mldev='not supported in Gemini API',
58+
has_union=True,
59+
),
60+
pytest_helper.TestTableItem(
61+
name='test_generate_content_with_vertex_dataset_source',
62+
parameters=types._CreateBatchJobParameters(
63+
model=_GEMINI_MODEL_FULL_NAME,
64+
src=types.BatchJobSource(
65+
vertex_dataset_name=_VERTEX_DATASET_INPUT_NAME,
66+
format='vertex-dataset',
67+
),
68+
config={
69+
'dest': {
70+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
71+
'format': 'vertex-dataset',
72+
},
73+
},
74+
),
75+
exception_if_mldev='one of',
76+
),
77+
pytest_helper.TestTableItem(
78+
name='test_generate_content_with_vertex_dataset_source_dict',
79+
parameters=types._CreateBatchJobParameters(
80+
model=_GEMINI_MODEL_FULL_NAME,
81+
src={
82+
'vertex_dataset_name': _VERTEX_DATASET_INPUT_NAME,
83+
'format': 'vertex-dataset',
84+
},
85+
config={
86+
'dest': {
87+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
88+
'format': 'vertex-dataset',
89+
},
90+
},
91+
),
92+
exception_if_mldev='one of',
93+
),
94+
]
95+
96+
pytestmark = [
97+
pytest.mark.usefixtures('mock_timestamped_unique_name'),
98+
pytest_helper.setup(
99+
file=__file__,
100+
globals_for_file=globals(),
101+
test_method='batches.create',
102+
test_table=test_table,
103+
),
104+
]
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_async_create(client):
109+
with pytest_helper.exception_if_mldev(client, ValueError):
110+
batch_job = await client.aio.batches.create(
111+
model=_GEMINI_MODEL,
112+
src=_VERTEX_DATASET_INPUT_NAME,
113+
config={
114+
'dest': {
115+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
116+
'format': 'vertex-dataset',
117+
},
118+
},
119+
)
120+
121+
assert batch_job.name.startswith('projects/')
122+
assert (
123+
batch_job.model == _GEMINI_MODEL_FULL_NAME
124+
) # Converted to Vertex full name.
125+
assert batch_job.src.vertex_dataset_name == _VERTEX_DATASET_INPUT_NAME
126+
assert batch_job.src.format == 'vertex-dataset'

google/genai/tests/transformers/test_t_batch.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,19 +172,36 @@ def test_batch_job_source_vertexai_valid_bigquery(self, vertex_client):
172172
result = t.t_batch_job_source(vertex_client, src_obj)
173173
assert result is src_obj
174174

175-
def test_batch_job_source_vertexai_valid_both(self, vertex_client):
175+
def test_batch_job_source_vertexai_valid_all(self, vertex_client):
176176
src_obj = types.BatchJobSource(
177177
gcs_uri=['gs://vertex-bucket/data.jsonl'],
178178
bigquery_uri='bq://project.dataset.table',
179+
vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
179180
)
180-
with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'):
181+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
182+
t.t_batch_job_source(vertex_client, src_obj)
183+
184+
def test_batch_job_source_vertexai_valid_gcs_and_bigquery(self, vertex_client):
185+
src_obj = types.BatchJobSource(
186+
gcs_uri=['gs://vertex-bucket/data.jsonl'],
187+
bigquery_uri='bq://project.dataset.table',
188+
)
189+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
190+
t.t_batch_job_source(vertex_client, src_obj)
191+
192+
def test_batch_job_source_vertexai_valid_bigquery_and_vertex_dataset(self, vertex_client):
193+
src_obj = types.BatchJobSource(
194+
bigquery_uri='bq://project.dataset.table',
195+
vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
196+
)
197+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
181198
t.t_batch_job_source(vertex_client, src_obj)
182199

183200
def test_batch_job_source_vertexai_invalid_neither_set(self, vertex_client):
184201
src_obj = types.BatchJobSource(
185202
file_name='files/data.csv'
186203
)
187-
with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'):
204+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
188205
t.t_batch_job_source(vertex_client, src_obj)
189206

190207

0 commit comments

Comments
 (0)