1515from __future__ import annotations
1616
1717import datetime
18+ import functools
1819import typing
19- from typing import Literal , Optional , Sequence , Tuple , Union
20+ from typing import Iterable , Literal , Optional , Sequence , Tuple , Union
2021
2122import bigframes_vendored .constants as constants
2223import bigframes_vendored .pandas .core .groupby as vendored_pandas_groupby
3839import bigframes .core .window_spec as window_specs
3940import bigframes .dataframe as df
4041import bigframes .dtypes as dtypes
42+ import bigframes .enums
43+ import bigframes .operations as ops
4144import bigframes .operations .aggregations as agg_ops
4245import bigframes .series as series
4346
@@ -54,6 +57,7 @@ def __init__(
5457 selected_cols : typing .Optional [typing .Sequence [str ]] = None ,
5558 dropna : bool = True ,
5659 as_index : bool = True ,
60+ by_key_is_singular : bool = False ,
5761 ):
5862 # TODO(tbergeron): Support more group-by expression types
5963 self ._block = block
@@ -64,6 +68,9 @@ def __init__(
6468 )
6569 }
6670 self ._by_col_ids = by_col_ids
71+ self ._by_key_is_singular = by_key_is_singular
72+ if by_key_is_singular :
73+ assert len (by_col_ids ) == 1 , "singular key should be exactly one group key"
6774
6875 self ._dropna = dropna
6976 self ._as_index = as_index
@@ -149,6 +156,59 @@ def head(self, n: int = 5) -> df.DataFrame:
149156 )
150157 )
151158
    def __iter__(self) -> Iterable[Tuple[blocks.Label, df.DataFrame]]:
        """Iterate over the groups, pandas-style.

        Yields one ``(group_key, DataFrame)`` pair per distinct group. When
        this groupby was constructed with ``by_key_is_singular=True`` (exactly
        one group column), the key is yielded as a scalar; otherwise it is
        yielded as a tuple of the per-column key values.

        NOTE(review): the distinct group keys are materialized client-side via
        ``to_pandas_batches()``, and one filtered DataFrame is built per key —
        presumably acceptable because the filtered blocks all reuse the cache
        created below; confirm the intended cost model for high-cardinality
        group keys.
        """
        # Remember the caller's index so it can be restored on each yielded
        # DataFrame after we temporarily re-index by the group columns.
        original_index_columns = self._block._index_columns
        original_index_labels = self._block._index_labels
        by_col_ids = self._by_col_ids
        block = self._block.reset_index(
            level=None,
            # Keep the original index columns so they can be recovered.
            drop=False,
            allow_duplicates=True,
            replacement=bigframes.enums.DefaultIndexKind.NULL,
        ).set_index(
            by_col_ids,
            # Keep by_col_ids in-place so the ordering doesn't change.
            drop=False,
            append=False,
        )
        block.cached(
            force=True,
            # All DataFrames will be filtered by by_col_ids, so
            # force block.cached() to cluster by the new index by explicitly
            # setting `session_aware=False`. This will ensure that the filters
            # are more efficient.
            session_aware=False,
        )
        # Distinct group keys (aggregate with no value columns). `dropna`
        # mirrors the groupby's own NA-key handling.
        keys_block, _ = block.aggregate(by_col_ids, dropna=self._dropna)
        for chunk in keys_block.to_pandas_batches():
            # MultiIndex.from_frame yields one tuple of key values per group,
            # even for a single group column (a 1-tuple in that case).
            for by_keys in pd.MultiIndex.from_frame(chunk.index.to_frame()):
                filtered_df = df.DataFrame(
                    # To ensure the cache is used, filter first, then reset the
                    # index before yielding the DataFrame.
                    block.filter(
                        # AND together `col == key` predicates, one per group
                        # column.
                        functools.reduce(
                            ops.and_op.as_expr,
                            (
                                ops.eq_op.as_expr(by_col, ex.const(by_key))
                                for by_col, by_key in zip(by_col_ids, by_keys)
                            ),
                        ),
                    ).set_index(
                        original_index_columns,
                        # We retained by_col_ids in the set_index call above,
                        # so it's safe to drop the duplicates now.
                        drop=True,
                        append=False,
                        index_labels=original_index_labels,
                    )
                )

                if self._by_key_is_singular:
                    # Single group column: unwrap the 1-tuple to a scalar key,
                    # matching pandas' behavior for a singular `by`.
                    yield by_keys[0], filtered_df
                else:
                    yield by_keys, filtered_df
152212 def size (self ) -> typing .Union [df .DataFrame , series .Series ]:
153213 agg_block , _ = self ._block .aggregate_size (
154214 by_column_ids = self ._by_col_ids ,
0 commit comments