Skip to content

Commit d250e14

Browse files
committed
feat: Add ai_generate_bool to the bigframes.bigquery package
1 parent 8804ada commit d250e14

File tree

10 files changed

+388
-0
lines changed

10 files changed

+388
-0
lines changed

bigframes/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import sys
2020

21+
from bigframes.bigquery._operations.ai import ai_generate_bool
2122
from bigframes.bigquery._operations.approx_agg import approx_top_count
2223
from bigframes.bigquery._operations.array import (
2324
array_agg,
@@ -57,6 +58,8 @@
5758
from bigframes.core import log_adapter
5859

5960
_functions = [
61+
# ai ops
62+
ai_generate_bool,
6063
# approximate aggregate ops
6164
approx_top_count,
6265
# array ops
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import functools
18+
import json
19+
from typing import Any, List, Literal, Mapping, Sequence, Tuple
20+
21+
from bigframes import clients, dtypes, series
22+
from bigframes.operations import ai_ops
23+
24+
25+
def ai_generate_bool(
    prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series],
    *,
    connection_id: str | None = None,
    endpoint: str | None = None,
    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
    model_params: Mapping[Any, Any] | None = None,
) -> series.Series:
    """Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> df = bpd.DataFrame({
        ...     "col_1": ["apple", "bear", "pear"],
        ...     "col_2": ["fruit", "animal", "animal"]
        ... })
        >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"]))
        0    {'result': True, 'full_response': '{"candidate...
        1    {'result': True, 'full_response': '{"candidate...
        2    {'result': False, 'full_response': '{"candidat...
        dtype: struct<result: bool, full_response: string, status: string>[pyarrow]

        >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
        0     True
        1     True
        2    False
        Name: result, dtype: boolean

        >>> model_params = {
        ...     "generation_config": {
        ...         "thinking_config": {
        ...             "thinking_budget": 0
        ...         }
        ...     }
        ... }
        >>> bbq.ai_generate_bool(
        ...     (df["col_1"], " is a ", df["col_2"]),
        ...     endpoint="gemini-2.5-pro",
        ...     model_params=model_params,
        ... ).struct.field("result")
        0     True
        1     True
        2    False
        Name: result, dtype: boolean

    Args:
        prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series]):
            A mixture of Series and string literals that specifies the prompt to send to the model.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
            If not provided, the connection from the current session will be used.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the model. You can specify any generally available
            or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
            uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects
            a recent stable version of Gemini to use.
        request_type (Literal["dedicated", "shared", "unspecified"]):
            Specifies the type of inference request to send to the Gemini model. The request type determines what
            quota the request uses.

            * "dedicated": function only uses Provisioned Throughput quota. The AI.GENERATE function returns the error Provisioned throughput is not purchased or is not active if Provisioned Throughput quota isn't available.
            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
            * "unspecified":

              * If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
              * If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first. If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
        model_params (Mapping[Any, Any]):
            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.

    Returns:
        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:

            * "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
            * "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element.
            * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
    """

    # Literal typing is not enforced at runtime, so validate eagerly with a
    # clear error instead of failing later in SQL generation.
    if request_type not in ("dedicated", "shared", "unspecified"):
        raise ValueError(f"Unsupported request type: {request_type}")

    prompt_context, series_list = _separate_context_and_series(prompt)

    # At least one Series is required: it anchors the session/connection and
    # drives the row-wise application of the operator.
    if not series_list:
        raise ValueError("Please provide at least one Series in the prompt")

    operator = ai_ops.AIGenerateBool(
        prompt_context=tuple(prompt_context),
        connection_id=_resolve_connection_id(series_list[0], connection_id),
        endpoint=endpoint,
        request_type=request_type,
        # The op stores model params as a JSON string so the dataclass stays
        # hashable; empty/None params are normalized to None.
        model_params=json.dumps(model_params) if model_params else None,
    )

    return series_list[0]._apply_nary_op(operator, series_list[1:])
119+
120+
121+
@functools.singledispatch
122+
def _separate_context_and_series(
123+
prompt: Any,
124+
) -> Tuple[List[str | None], List[series.Series]]:
125+
"""
126+
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
127+
in the prompt. The original item order is kept.
128+
For example:
129+
Input: ("str1", series1, "str2", "str3", series2)
130+
Output: ["str1", None, "str2", "str3", None], [series1, series2]
131+
"""
132+
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
133+
134+
135+
@_separate_context_and_series.register
def _(
    prompt: series.Series,
) -> Tuple[List[str | None], List[series.Series]]:
    # A bare Series carries no literal context: one placeholder, one Series.
    # Object-ref (blob) series are swapped for their read URLs so the model
    # receives multimodal input it can fetch.
    if prompt.dtype == dtypes.OBJ_REF_DTYPE:
        prompt = prompt.blob.read_url()
    return [None], [prompt]
143+
144+
145+
@_separate_context_and_series.register(list)
@_separate_context_and_series.register(tuple)
def _(
    prompt: Sequence[str | series.Series],
) -> Tuple[List[str | None], List[series.Series]]:
    # Walk the prompt once, routing literals to the context list and Series
    # to the series list, with None marking each Series' position.
    prompt_context: List[str | None] = []
    series_list: List[series.Series] = []

    for item in prompt:
        if isinstance(item, str):
            prompt_context.append(item)
            continue

        if not isinstance(item, series.Series):
            raise ValueError(f"Unsupported type in prompt: {type(item)}")

        prompt_context.append(None)
        if item.dtype == dtypes.OBJ_REF_DTYPE:
            # Blob (object-ref) columns become read URLs for multimodal input.
            item = item.blob.read_url()
        series_list.append(item)

    return prompt_context, series_list
170+
171+
172+
def _resolve_connection_id(s: series.Series, connection_id: str | None):
    """Return the canonical BQ connection id for this Series' session.

    Falls back to the session-level default connection when no explicit
    connection_id is supplied.
    """
    session = s._session
    return clients.get_canonical_bq_connection_id(
        connection_id or session._bq_connection,
        session._project,
        session._location,
    )

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import functools
1818
import typing
1919

20+
from bigframes_vendored import ibis
2021
import bigframes_vendored.ibis.expr.api as ibis_api
2122
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
23+
import bigframes_vendored.ibis.expr.operations.ai_ops as ai_ops
2224
import bigframes_vendored.ibis.expr.operations.generic as ibis_generic
2325
import bigframes_vendored.ibis.expr.operations.udf as ibis_udf
2426
import bigframes_vendored.ibis.expr.types as ibis_types
@@ -1963,6 +1965,32 @@ def struct_op_impl(
19631965
return ibis_types.struct(data)
19641966

19651967

1968+
@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
def ai_generate_bool(
    *values: ibis_types.Value, op: ops.AIGenerateBool
) -> ibis_dtypes.StructValue:
    """Lower the AIGenerateBool scalar op to the vendored ibis AI node.

    op.prompt_context holds the literal prompt parts in order; each None
    entry is a placeholder that consumes the next compiled column value.
    """
    remaining_values = iter(values)
    prompt = {
        f"_field_{idx + 1}": (next(remaining_values) if part is None else part)
        for idx, part in enumerate(op.prompt_context)
    }

    return ai_ops.AIGenerateBool(
        ibis.struct(prompt),
        op.connection_id,
        op.endpoint,
        op.request_type.upper(),
        op.model_params,
    ).to_expr()
1992+
1993+
19661994
@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
19671995
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
19681996
return bigframes.core.compile.default_ordering.gen_row_key(values)

bigframes/operations/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from bigframes.operations.ai_ops import AIGenerateBool
1718
from bigframes.operations.array_ops import (
1819
ArrayIndexOp,
1920
ArrayReduceOp,
@@ -408,6 +409,8 @@
408409
"geo_x_op",
409410
"geo_y_op",
410411
"GeoStDistanceOp",
412+
# AI ops
413+
"AIGenerateBool",
411414
# Numpy ops mapping
412415
"NUMPY_TO_BINOP",
413416
"NUMPY_TO_OP",

bigframes/operations/ai_ops.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import ClassVar, Literal, Tuple
19+
20+
import pandas as pd
21+
import pyarrow as pa
22+
23+
from bigframes import dtypes
24+
from bigframes.operations import base_ops
25+
26+
27+
@dataclasses.dataclass(frozen=True)
class AIGenerateBool(base_ops.NaryOp):
    """Scalar n-ary op representing the BigQuery AI.GENERATE_BOOL function."""

    name: ClassVar[str] = "ai_generate_bool"

    # Literal prompt parts in order; None entries are the placeholders for
    # column references, filled in at compile time.
    # NOTE: the annotation is variadic (`, ...`) — `Tuple[str | None]` would
    # type the field as a 1-tuple, but any number of parts is allowed.
    prompt_context: Tuple[str | None, ...]
    connection_id: str
    endpoint: str | None
    request_type: Literal["dedicated", "shared", "unspecified"]
    # JSON-serialized generateContent model parameters, or None when unset.
    model_params: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        # The SQL function always yields
        # STRUCT<result BOOL, full_response STRING, status STRING>,
        # independent of the input column types.
        return pd.ArrowDtype(
            pa.struct(
                (
                    pa.field("result", pa.bool_()),
                    pa.field("full_response", pa.string()),
                    pa.field("status", pa.string()),
                )
            )
        )
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pandas.testing
17+
18+
import bigframes.bigquery as bbq
19+
20+
21+
def test_ai_generate_bool_multi_model(session):
    """AI.GENERATE_BOOL accepts blob (image) columns as multimodal prompt input."""
    df = session.from_glob_path(
        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
    )

    prompt = (df["image"], " contains an animal")
    result = bbq.ai_generate_bool(prompt).struct.field("result")

    expected = pd.Series([True, True, False, False, False], name="result")
    pandas.testing.assert_series_equal(
        result.to_pandas(),
        expected,
        check_dtype=False,
        check_index=False,
    )
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pandas.testing
17+
18+
import bigframes.bigquery as bbq
19+
import bigframes.pandas as bpd
20+
21+
22+
def test_ai_generate_bool(session):
    """End-to-end AI.GENERATE_BOOL with an explicit endpoint and model params."""
    subjects = bpd.Series(["apple", "bear"], session=session)
    categories = bpd.Series(["fruit", "tree"], session=session)
    model_params = {"generation_config": {"thinking_config": {"thinking_budget": 0}}}

    result = bbq.ai_generate_bool(
        (subjects, " is a ", categories),
        endpoint="gemini-2.5-flash",
        model_params=model_params,
    ).struct.field("result")

    expected = pd.Series([True, False], name="result")
    pandas.testing.assert_series_equal(
        result.to_pandas(),
        expected,
        check_dtype=False,
        check_index=False,
    )

0 commit comments

Comments
 (0)