Commit ad370f2

test: add ScalarArrowUDFBench and ScalarArrowIterUDFBench

1 parent 9677aae

1 file changed: python/benchmarks/bench_eval_type.py
Lines changed: 308 additions & 0 deletions
@@ -35,6 +35,8 @@
 from pyspark.cloudpickle import dumps as cloudpickle_dumps
 from pyspark.serializers import write_int, write_long
 from pyspark.sql.types import (
+    BinaryType,
+    BooleanType,
     DoubleType,
     IntegerType,
     StringType,
@@ -110,6 +112,42 @@ def _build_udf_payload(
     write_long(0, buf)  # result_id
 
 
+def _build_scalar_arrow_data(
+    arrow_batch: pa.RecordBatch,
+    num_batches: int,
+    buf: io.BytesIO,
+) -> None:
+    """Write a plain Arrow IPC stream with *num_batches* copies of *arrow_batch*."""
+    writer = pa.RecordBatchStreamWriter(buf, arrow_batch.schema)
+    for _ in range(num_batches):
+        writer.write_batch(arrow_batch)
+    writer.close()
+
+
+def _build_scalar_worker_input(
+    eval_type: int,
+    udf_func: Callable[..., Any],
+    return_type: StructType,
+    arg_offsets: list[int],
+    arrow_batch: pa.RecordBatch,
+    num_batches: int,
+) -> bytes:
+    """Assemble the full binary stream for scalar (non-grouped) eval types."""
+    buf = io.BytesIO()
+
+    _build_preamble(buf)
+    write_int(eval_type, buf)
+
+    write_int(0, buf)  # RunnerConf (0 key-value pairs)
+    write_int(0, buf)  # EvalConf (0 key-value pairs)
+
+    _build_udf_payload(udf_func, return_type, arg_offsets, buf)
+    _build_scalar_arrow_data(arrow_batch, num_batches, buf)
+    write_int(-4, buf)  # SpecialLengths.END_OF_STREAM
+
+    return buf.getvalue()
+
+
 def _build_grouped_arrow_data(
     arrow_batch: pa.RecordBatch,
     num_groups: int,
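An aside on the wire format (not part of the commit): the data section written by _build_scalar_arrow_data is an ordinary Arrow IPC stream, so it round-trips with stock pyarrow. A minimal sketch, assuming only pyarrow:

    import io
    import pyarrow as pa

    buf = io.BytesIO()
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["col_0"])

    # Same pattern as _build_scalar_arrow_data: one stream, N copies of the batch.
    writer = pa.RecordBatchStreamWriter(buf, batch.schema)
    for _ in range(3):
        writer.write_batch(batch)
    writer.close()

    # A consumer sees a standard stream of record batches.
    reader = pa.ipc.open_stream(buf.getvalue())
    assert sum(b.num_rows for b in reader) == 9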
@@ -180,6 +218,22 @@ def _build_grouped_arg_offsets(n_cols: int, n_keys: int = 0) -> list[int]:
     return [len(offsets)] + offsets
 
 
+def _make_typed_batch(rows: int, n_cols: int) -> tuple[pa.RecordBatch, IntegerType]:
+    """Columns cycling through int64, string, binary, boolean; reflects realistic serde costs."""
+    type_cycle = [
+        (lambda r: pa.array(np.random.randint(0, 1000, r, dtype=np.int64)), IntegerType()),
+        (lambda r: pa.array([f"s{j}" for j in range(r)]), StringType()),
+        (lambda r: pa.array([f"b{j}".encode() for j in range(r)]), BinaryType()),
+        (lambda r: pa.array(np.random.choice([True, False], r)), BooleanType()),
+    ]
+    arrays = [type_cycle[i % len(type_cycle)][0](rows) for i in range(n_cols)]
+    fields = [StructField(f"col_{i}", type_cycle[i % len(type_cycle)][1]) for i in range(n_cols)]
+    return (
+        pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields]),
+        IntegerType(),
+    )
+
+
 def _make_grouped_batch(rows_per_group: int, n_cols: int) -> tuple[pa.RecordBatch, StructType]:
     """``group_key (int64)`` + ``(n_cols - 1)`` float32 value columns."""
     arrays = [pa.array(np.zeros(rows_per_group, dtype=np.int64))] + [
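To illustrate the column cycling in _make_typed_batch (an editorial sketch, assuming the helper above): with more columns than the four types, the cycle wraps around.

    batch, ret_type = _make_typed_batch(rows=4, n_cols=6)
    print(batch.schema)
    # col_0: int64
    # col_1: string
    # col_2: binary
    # col_3: bool
    # col_4: int64   <- the four-type cycle wraps here
    # col_5: string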
@@ -346,3 +400,257 @@ def time_mixed_types_two_args(self):
     def peakmem_mixed_types_two_args(self):
         """Mixed column types, 2-arg UDF with key, 3 rows/group, 1600 groups."""
         self._run(self._two_args_input)
+
+
+# ---------------------------------------------------------------------------
+# SQL_SCALAR_ARROW_UDF
+# ---------------------------------------------------------------------------
+
+
+class ScalarArrowUDFBench:
+    """Full worker round-trip for ``SQL_SCALAR_ARROW_UDF``."""
+
+    def setup(self):
+        eval_type = PythonEvalType.SQL_SCALAR_ARROW_UDF
+
+        # ---- varying batch size (mixed types, identity UDF) ----
+        # JVM splits at maxRecordsPerBatch (default 10k), so large cases use
+        # 10k rows with proportionally more batches to keep total rows constant.
+        for name, (rows, n_cols, num_batches) in {
+            "small_few": (1_000, 5, 1_500),
+            "small_many": (1_000, 50, 200),
+            "large_few": (10_000, 5, 3_500),
+            "large_many": (10_000, 50, 400),
+        }.items():
+            batch, ret_type = _make_typed_batch(rows, n_cols)
+            setattr(
+                self,
+                f"_{name}_input",
+                _build_scalar_worker_input(
+                    eval_type,
+                    lambda c: c,
+                    ret_type,
+                    [0],
+                    batch,
+                    num_batches=num_batches,
+                ),
+            )
+
+        # ---- compute: arithmetic on two columns ----
+        compute_arrays = [pa.array(np.random.rand(10_000)) for _ in range(3)]
+        compute_batch = pa.RecordBatch.from_arrays(
+            compute_arrays, names=[f"col_{i}" for i in range(3)]
+        )
+
+        def arrow_compute(a, b):
+            import pyarrow.compute as pc
+
+            return pc.add(a, pc.multiply(b, 2))
+
+        self._compute_input = _build_scalar_worker_input(
+            eval_type,
+            arrow_compute,
+            DoubleType(),
+            [0, 1],
+            compute_batch,
+            num_batches=500,
+        )
+
+        # ---- mixed types: string manipulation ----
+        mixed_batch, _ = _make_mixed_batch(3)
+
+        def upper_str(s):
+            import pyarrow.compute as pc
+
+            return pc.utf8_upper(s)
+
+        self._mixed_input = _build_scalar_worker_input(
+            eval_type,
+            upper_str,
+            StringType(),
+            [1],  # str_col
+            mixed_batch,
+            num_batches=1_300,
+        )
+
+    # -- benchmarks ---------------------------------------------------------
+
+    def _run(self, input_bytes):
+        worker_main(io.BytesIO(input_bytes), io.BytesIO())
+
+    def time_small_batches_few_cols(self):
+        """1k rows/batch, 5 cols, 1500 batches."""
+        self._run(self._small_few_input)
+
+    def peakmem_small_batches_few_cols(self):
+        """1k rows/batch, 5 cols, 1500 batches."""
+        self._run(self._small_few_input)
+
+    def time_small_batches_many_cols(self):
+        """1k rows/batch, 50 cols, 200 batches."""
+        self._run(self._small_many_input)
+
+    def peakmem_small_batches_many_cols(self):
+        """1k rows/batch, 50 cols, 200 batches."""
+        self._run(self._small_many_input)
+
+    def time_large_batches_few_cols(self):
+        """10k rows/batch, 5 cols, 3500 batches."""
+        self._run(self._large_few_input)
+
+    def peakmem_large_batches_few_cols(self):
+        """10k rows/batch, 5 cols, 3500 batches."""
+        self._run(self._large_few_input)
+
+    def time_large_batches_many_cols(self):
+        """10k rows/batch, 50 cols, 400 batches."""
+        self._run(self._large_many_input)
+
+    def peakmem_large_batches_many_cols(self):
+        """10k rows/batch, 50 cols, 400 batches."""
+        self._run(self._large_many_input)
+
+    def time_compute(self):
+        """10k rows/batch, 3 cols, 500 batches, arithmetic UDF."""
+        self._run(self._compute_input)
+
+    def peakmem_compute(self):
+        """10k rows/batch, 3 cols, 500 batches, arithmetic UDF."""
+        self._run(self._compute_input)
+
+    def time_mixed_types(self):
+        """Mixed column types, string UDF, 3 rows/batch, 1300 batches."""
+        self._run(self._mixed_input)
+
+    def peakmem_mixed_types(self):
+        """Mixed column types, string UDF, 3 rows/batch, 1300 batches."""
+        self._run(self._mixed_input)
+
+
+# ---------------------------------------------------------------------------
+# SQL_SCALAR_ARROW_ITER_UDF
+# ---------------------------------------------------------------------------
+
+
+class ScalarArrowIterUDFBench:
+    """Full worker round-trip for ``SQL_SCALAR_ARROW_ITER_UDF``.
+
+    Same Arrow IPC wire format as ``SQL_SCALAR_ARROW_UDF`` but the UDF
+    receives/returns ``Iterator[pa.Array]`` instead of a single array.
+    """
+
+    def setup(self):
+        eval_type = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
+
+        # ---- varying batch size (mixed types, identity UDF) ----
+        for name, (rows, n_cols, num_batches) in {
+            "small_few": (1_000, 5, 1_500),
+            "small_many": (1_000, 50, 200),
+            "large_few": (10_000, 5, 3_500),
+            "large_many": (10_000, 50, 400),
+        }.items():
+            batch, ret_type = _make_typed_batch(rows, n_cols)
+            setattr(
+                self,
+                f"_{name}_input",
+                _build_scalar_worker_input(
+                    eval_type,
+                    lambda it: (c for c in it),
+                    ret_type,
+                    [0],
+                    batch,
+                    num_batches=num_batches,
+                ),
+            )
+
+        # ---- compute: arithmetic on two columns ----
+        compute_arrays = [pa.array(np.random.rand(10_000)) for _ in range(3)]
+        compute_batch = pa.RecordBatch.from_arrays(
+            compute_arrays, names=[f"col_{i}" for i in range(3)]
+        )
+
+        def arrow_compute_iter(it):
+            import pyarrow.compute as pc
+
+            for a, b in it:
+                yield pc.add(a, pc.multiply(b, 2))
+
+        self._compute_input = _build_scalar_worker_input(
+            eval_type,
+            arrow_compute_iter,
+            DoubleType(),
+            [0, 1],
+            compute_batch,
+            num_batches=500,
+        )
+
+        # ---- mixed types: string manipulation ----
+        mixed_batch, _ = _make_mixed_batch(3)
+
+        def upper_str_iter(it):
+            import pyarrow.compute as pc
+
+            for s in it:
+                yield pc.utf8_upper(s)
+
+        self._mixed_input = _build_scalar_worker_input(
+            eval_type,
+            upper_str_iter,
+            StringType(),
+            [1],  # str_col
+            mixed_batch,
+            num_batches=1_300,
+        )
+
+    # -- benchmarks ---------------------------------------------------------
+
+    def _run(self, input_bytes):
+        worker_main(io.BytesIO(input_bytes), io.BytesIO())
+
+    def time_small_batches_few_cols(self):
+        """1k rows/batch, 5 cols, 1500 batches."""
+        self._run(self._small_few_input)
+
+    def peakmem_small_batches_few_cols(self):
+        """1k rows/batch, 5 cols, 1500 batches."""
+        self._run(self._small_few_input)
+
+    def time_small_batches_many_cols(self):
+        """1k rows/batch, 50 cols, 200 batches."""
+        self._run(self._small_many_input)
+
+    def peakmem_small_batches_many_cols(self):
+        """1k rows/batch, 50 cols, 200 batches."""
+        self._run(self._small_many_input)
+
+    def time_large_batches_few_cols(self):
+        """10k rows/batch, 5 cols, 3500 batches."""
+        self._run(self._large_few_input)
+
+    def peakmem_large_batches_few_cols(self):
+        """10k rows/batch, 5 cols, 3500 batches."""
+        self._run(self._large_few_input)
+
+    def time_large_batches_many_cols(self):
+        """10k rows/batch, 50 cols, 400 batches."""
+        self._run(self._large_many_input)
+
+    def peakmem_large_batches_many_cols(self):
+        """10k rows/batch, 50 cols, 400 batches."""
+        self._run(self._large_many_input)
+
+    def time_compute(self):
+        """10k rows/batch, 3 cols, 500 batches, arithmetic UDF."""
+        self._run(self._compute_input)
+
+    def peakmem_compute(self):
+        """10k rows/batch, 3 cols, 500 batches, arithmetic UDF."""
+        self._run(self._compute_input)
+
+    def time_mixed_types(self):
+        """Mixed column types, string UDF, 3 rows/batch, 1300 batches."""
+        self._run(self._mixed_input)
+
+    def peakmem_mixed_types(self):
+        """Mixed column types, string UDF, 3 rows/batch, 1300 batches."""
+        self._run(self._mixed_input)
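Both classes use the setup/time_*/peakmem_* naming that the rest of this file already follows (airspeed velocity conventions, assuming the suite runs under asv). For a standalone smoke test outside the harness, a minimal sketch (the import path is hypothetical and depends on the working directory):

    # setup() prebuilds the byte streams once; each benchmark method then
    # replays one of them through worker_main.
    from bench_eval_type import ScalarArrowIterUDFBench  # hypothetical import path

    bench = ScalarArrowIterUDFBench()
    bench.setup()
    bench.time_compute()  # 500 batches through the iterator-UDF code path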
