35 | 35 | from pyspark.cloudpickle import dumps as cloudpickle_dumps |
36 | 36 | from pyspark.serializers import write_int, write_long |
37 | 37 | from pyspark.sql.types import ( |
| 38 | + BinaryType, |
| 39 | + BooleanType, |
| | + DataType, |
38 | 40 | DoubleType, |
39 | 41 | IntegerType, |
40 | 42 | StringType, |
@@ -110,6 +112,42 @@ def _build_udf_payload( |
110 | 112 | write_long(0, buf) # result_id |
111 | 113 |
|
112 | 114 |
|
| 115 | +def _build_scalar_arrow_data( |
| 116 | + arrow_batch: pa.RecordBatch, |
| 117 | + num_batches: int, |
| 118 | + buf: io.BytesIO, |
| 119 | +) -> None: |
| 120 | + """Write a plain Arrow IPC stream with *num_batches* copies of *arrow_batch*. |
| | + |
| | + Repeating one pre-built batch keeps data generation out of the measured |
| | + path while the worker still deserializes every copy. |
| | + """ |
| 121 | + writer = pa.RecordBatchStreamWriter(buf, arrow_batch.schema) |
| 122 | + for _ in range(num_batches): |
| 123 | + writer.write_batch(arrow_batch) |
| 124 | + writer.close() |
| 125 | + |
| 126 | + |
| 127 | +def _build_scalar_worker_input( |
| 128 | + eval_type: int, |
| 129 | + udf_func: Callable[..., Any], |
| | + return_type: DataType, |
| 131 | + arg_offsets: list[int], |
| 132 | + arrow_batch: pa.RecordBatch, |
| 133 | + num_batches: int, |
| 134 | +) -> bytes: |
| 135 | + """Assemble the full binary stream for scalar (non-grouped) eval types.""" |
| 136 | + buf = io.BytesIO() |
| 137 | + |
| 138 | + _build_preamble(buf) |
| 139 | + write_int(eval_type, buf) |
| 140 | + |
| 141 | + write_int(0, buf) # RunnerConf (0 key-value pairs) |
| 142 | + write_int(0, buf) # EvalConf (0 key-value pairs) |
| 143 | + |
| 144 | + _build_udf_payload(udf_func, return_type, arg_offsets, buf) |
| 145 | + _build_scalar_arrow_data(arrow_batch, num_batches, buf) |
| 146 | + write_int(-4, buf) # SpecialLengths.END_OF_STREAM |
| 147 | + |
| 148 | + return buf.getvalue() |
| 149 | + |
| 150 | + |
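| | +# Payload layout sketch: the field order below mirrors the calls in |
| | +# _build_scalar_worker_input above; see the worker protocol itself for the |
| | +# authoritative framing. |
| | +# |
| | +# preamble |
| | +# eval_type (int) |
| | +# runner conf: 0 key-value pairs (int) |
| | +# eval conf: 0 key-value pairs (int) |
| | +# pickled UDF payload (see _build_udf_payload) |
| | +# Arrow IPC stream: num_batches copies of one RecordBatch |
| | +# -4 (SpecialLengths.END_OF_STREAM) |
| | + |
| | + |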
113 | 151 | def _build_grouped_arrow_data( |
114 | 152 | arrow_batch: pa.RecordBatch, |
115 | 153 | num_groups: int, |
@@ -180,6 +218,22 @@ def _build_grouped_arg_offsets(n_cols: int, n_keys: int = 0) -> list[int]: |
180 | 218 | return [len(offsets)] + offsets |
181 | 219 |
|
182 | 220 |
|
| 221 | +def _make_typed_batch(rows: int, n_cols: int) -> tuple[pa.RecordBatch, IntegerType]: |
| 222 | + """Columns cycle through int64, string, binary, and boolean to reflect realistic serde costs. |
| | + |
| | + Returns the batch plus the return type declared for the identity UDF on |
| | + ``col_0`` (its int64 values lie in [0, 1000) and fit ``IntegerType``). |
| | + """ |
| 223 | + type_cycle = [ |
| 224 | + (lambda r: pa.array(np.random.randint(0, 1000, r, dtype=np.int64)), IntegerType()), |
| 225 | + (lambda r: pa.array([f"s{j}" for j in range(r)]), StringType()), |
| 226 | + (lambda r: pa.array([f"b{j}".encode() for j in range(r)]), BinaryType()), |
| 227 | + (lambda r: pa.array(np.random.choice([True, False], r)), BooleanType()), |
| 228 | + ] |
| 229 | + arrays = [type_cycle[i % len(type_cycle)][0](rows) for i in range(n_cols)] |
| 230 | + fields = [StructField(f"col_{i}", type_cycle[i % len(type_cycle)][1]) for i in range(n_cols)] |
| 231 | + return ( |
| 232 | + pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields]), |
| 233 | + IntegerType(), |
| 234 | + ) |
| 235 | + |
| 236 | + |
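| | +# Hypothetical smoke test (illustrative only, not part of the benchmark |
| | +# suite): build one tiny scalar payload and push it through the worker once |
| | +# to check that the stream parses before any timed run. |
| | +def _smoke_test_scalar_input() -> None: |
| | +    batch, ret_type = _make_typed_batch(rows=4, n_cols=2) |
| | +    payload = _build_scalar_worker_input( |
| | +        PythonEvalType.SQL_SCALAR_ARROW_UDF, |
| | +        lambda c: c,  # identity UDF on col_0 |
| | +        ret_type, |
| | +        [0], |
| | +        batch, |
| | +        num_batches=1, |
| | +    ) |
| | +    worker_main(io.BytesIO(payload), io.BytesIO()) |
| | + |
| | + |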
183 | 237 | def _make_grouped_batch(rows_per_group: int, n_cols: int) -> tuple[pa.RecordBatch, StructType]: |
184 | 238 | """``group_key (int64)`` + ``(n_cols - 1)`` float32 value columns.""" |
185 | 239 | arrays = [pa.array(np.zeros(rows_per_group, dtype=np.int64))] + [ |
@@ -346,3 +400,257 @@ def time_mixed_types_two_args(self): |
346 | 400 | def peakmem_mixed_types_two_args(self): |
347 | 401 | """Mixed column types, 2-arg UDF with key, 3 rows/group, 1600 groups.""" |
348 | 402 | self._run(self._two_args_input) |
| 403 | + |
| 404 | + |
| 405 | +# --------------------------------------------------------------------------- |
| 406 | +# SQL_SCALAR_ARROW_UDF |
| 407 | +# --------------------------------------------------------------------------- |
| 408 | + |
| 409 | + |
| 410 | +class ScalarArrowUDFBench: |
| 411 | + """Full worker round-trip for ``SQL_SCALAR_ARROW_UDF``.""" |
| 412 | + |
| 413 | + def setup(self): |
| 414 | + eval_type = PythonEvalType.SQL_SCALAR_ARROW_UDF |
| 415 | + |
| 416 | + # ---- varying batch size (mixed types, identity UDF) ---- |
| 417 | + # The JVM splits input at maxRecordsPerBatch (default 10k), so the |
| 418 | + # "large" cases cap at 10k rows per batch and scale the batch count |
| | + # instead of the batch size. |
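| | + # Each case becomes an attribute such as self._small_few_input, which |
| | + # the matching time_*/peakmem_* methods below consume. |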
| 419 | + for name, (rows, n_cols, num_batches) in { |
| 420 | + "small_few": (1_000, 5, 1_500), |
| 421 | + "small_many": (1_000, 50, 200), |
| 422 | + "large_few": (10_000, 5, 3_500), |
| 423 | + "large_many": (10_000, 50, 400), |
| 424 | + }.items(): |
| 425 | + batch, ret_type = _make_typed_batch(rows, n_cols) |
| 426 | + setattr( |
| 427 | + self, |
| 428 | + f"_{name}_input", |
| 429 | + _build_scalar_worker_input( |
| 430 | + eval_type, |
| 431 | + lambda c: c, # identity UDF |
| 432 | + ret_type, |
| 433 | + [0], |
| 434 | + batch, |
| 435 | + num_batches=num_batches, |
| 436 | + ), |
| 437 | + ) |
| 438 | + |
| 439 | + # ---- compute: arithmetic on two columns ---- |
| | + # Three float64 columns; the UDF reads only col_0 and col_1. |
| 440 | + compute_arrays = [pa.array(np.random.rand(10_000)) for _ in range(3)] |
| 441 | + compute_batch = pa.RecordBatch.from_arrays( |
| 442 | + compute_arrays, names=[f"col_{i}" for i in range(3)] |
| 443 | + ) |
| 444 | + |
| 445 | + def arrow_compute(a, b): |
| 446 | + import pyarrow.compute as pc |
| 447 | + |
| 448 | + return pc.add(a, pc.multiply(b, 2)) |
| 449 | + |
| 450 | + self._compute_input = _build_scalar_worker_input( |
| 451 | + eval_type, |
| 452 | + arrow_compute, |
| 453 | + DoubleType(), |
| 454 | + [0, 1], |
| 455 | + compute_batch, |
| 456 | + num_batches=500, |
| 457 | + ) |
| 458 | + |
| 459 | + # ---- mixed types: string manipulation ---- |
| 460 | + mixed_batch, _ = _make_mixed_batch(3) |
| 461 | + |
| 462 | + def upper_str(s): |
| 463 | + import pyarrow.compute as pc |
| 464 | + |
| 465 | + return pc.utf8_upper(s) |
| 466 | + |
| 467 | + self._mixed_input = _build_scalar_worker_input( |
| 468 | + eval_type, |
| 469 | + upper_str, |
| 470 | + StringType(), |
| 471 | + [1], # str_col |
| 472 | + mixed_batch, |
| 473 | + num_batches=1_300, |
| 474 | + ) |
| 475 | + |
| 476 | + # -- benchmarks --------------------------------------------------------- |
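| | + # Assuming the asv (airspeed velocity) runner: time_* methods are timed, |
| | + # peakmem_* methods report peak memory, so each pair exercises the same |
| | + # workload under both measurements. |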
| 477 | + |
| 478 | + def _run(self, input_bytes): |
| 479 | + worker_main(io.BytesIO(input_bytes), io.BytesIO()) |
| 480 | + |
| 481 | + def time_small_batches_few_cols(self): |
| 482 | + """1k rows/batch, 5 cols, 1500 batches.""" |
| 483 | + self._run(self._small_few_input) |
| 484 | + |
| 485 | + def peakmem_small_batches_few_cols(self): |
| 486 | + """1k rows/batch, 5 cols, 1500 batches.""" |
| 487 | + self._run(self._small_few_input) |
| 488 | + |
| 489 | + def time_small_batches_many_cols(self): |
| 490 | + """1k rows/batch, 50 cols, 200 batches.""" |
| 491 | + self._run(self._small_many_input) |
| 492 | + |
| 493 | + def peakmem_small_batches_many_cols(self): |
| 494 | + """1k rows/batch, 50 cols, 200 batches.""" |
| 495 | + self._run(self._small_many_input) |
| 496 | + |
| 497 | + def time_large_batches_few_cols(self): |
| 498 | + """10k rows/batch, 5 cols, 3500 batches.""" |
| 499 | + self._run(self._large_few_input) |
| 500 | + |
| 501 | + def peakmem_large_batches_few_cols(self): |
| 502 | + """10k rows/batch, 5 cols, 3500 batches.""" |
| 503 | + self._run(self._large_few_input) |
| 504 | + |
| 505 | + def time_large_batches_many_cols(self): |
| 506 | + """10k rows/batch, 50 cols, 400 batches.""" |
| 507 | + self._run(self._large_many_input) |
| 508 | + |
| 509 | + def peakmem_large_batches_many_cols(self): |
| 510 | + """10k rows/batch, 50 cols, 400 batches.""" |
| 511 | + self._run(self._large_many_input) |
| 512 | + |
| 513 | + def time_compute(self): |
| 514 | + """10k rows/batch, 3 cols, 500 batches, arithmetic UDF.""" |
| 515 | + self._run(self._compute_input) |
| 516 | + |
| 517 | + def peakmem_compute(self): |
| 518 | + """10k rows/batch, 3 cols, 500 batches, arithmetic UDF.""" |
| 519 | + self._run(self._compute_input) |
| 520 | + |
| 521 | + def time_mixed_types(self): |
| 522 | + """Mixed column types, string UDF, 3 rows/batch, 1300 batches.""" |
| 523 | + self._run(self._mixed_input) |
| 524 | + |
| 525 | + def peakmem_mixed_types(self): |
| 526 | + """Mixed column types, string UDF, 3 rows/batch, 1300 batches.""" |
| 527 | + self._run(self._mixed_input) |
| 528 | + |
| 529 | + |
| 530 | +# --------------------------------------------------------------------------- |
| 531 | +# SQL_SCALAR_ARROW_ITER_UDF |
| 532 | +# --------------------------------------------------------------------------- |
| 533 | + |
| 534 | + |
| 535 | +class ScalarArrowIterUDFBench: |
| 536 | + """Full worker round-trip for ``SQL_SCALAR_ARROW_ITER_UDF``. |
| 537 | +
|
| 538 | + Same Arrow IPC wire format as ``SQL_SCALAR_ARROW_UDF``, but the UDF |
| 539 | + receives and returns an iterator: an ``Iterator[pa.Array]`` for |
| | + single-arg UDFs, an iterator of array tuples for multi-arg UDFs. |
| 540 | + """ |
| 541 | + |
| 542 | + def setup(self): |
| 543 | + eval_type = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF |
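| | + # Setup mirrors ScalarArrowUDFBench.setup(); only the UDFs change, to the |
| | + # iterator style this eval type expects. |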
| 544 | + |
| 545 | + # ---- varying batch size (mixed types, identity UDF) ---- |
| 546 | + for name, (rows, n_cols, num_batches) in { |
| 547 | + "small_few": (1_000, 5, 1_500), |
| 548 | + "small_many": (1_000, 50, 200), |
| 549 | + "large_few": (10_000, 5, 3_500), |
| 550 | + "large_many": (10_000, 50, 400), |
| 551 | + }.items(): |
| 552 | + batch, ret_type = _make_typed_batch(rows, n_cols) |
| 553 | + setattr( |
| 554 | + self, |
| 555 | + f"_{name}_input", |
| 556 | + _build_scalar_worker_input( |
| 557 | + eval_type, |
| 558 | + lambda it: (c for c in it), # identity: re-yield each batch |
| 559 | + ret_type, |
| 560 | + [0], |
| 561 | + batch, |
| 562 | + num_batches=num_batches, |
| 563 | + ), |
| 564 | + ) |
| 565 | + |
| 566 | + # ---- compute: arithmetic on two columns ---- |
| | + # Three float64 columns; the UDF reads only col_0 and col_1. |
| 567 | + compute_arrays = [pa.array(np.random.rand(10_000)) for _ in range(3)] |
| 568 | + compute_batch = pa.RecordBatch.from_arrays( |
| 569 | + compute_arrays, names=[f"col_{i}" for i in range(3)] |
| 570 | + ) |
| 571 | + |
| 572 | + def arrow_compute_iter(it): |
| 573 | + import pyarrow.compute as pc |
| 574 | + |
| 575 | + for a, b in it: |
| 576 | + yield pc.add(a, pc.multiply(b, 2)) |
| 577 | + |
| 578 | + self._compute_input = _build_scalar_worker_input( |
| 579 | + eval_type, |
| 580 | + arrow_compute_iter, |
| 581 | + DoubleType(), |
| 582 | + [0, 1], |
| 583 | + compute_batch, |
| 584 | + num_batches=500, |
| 585 | + ) |
| 586 | + |
| 587 | + # ---- mixed types: string manipulation ---- |
| 588 | + mixed_batch, _ = _make_mixed_batch(3) |
| 589 | + |
| 590 | + def upper_str_iter(it): |
| 591 | + import pyarrow.compute as pc |
| 592 | + |
| 593 | + for s in it: |
| 594 | + yield pc.utf8_upper(s) |
| 595 | + |
| 596 | + self._mixed_input = _build_scalar_worker_input( |
| 597 | + eval_type, |
| 598 | + upper_str_iter, |
| 599 | + StringType(), |
| 600 | + [1], # str_col |
| 601 | + mixed_batch, |
| 602 | + num_batches=1_300, |
| 603 | + ) |
| 604 | + |
| 605 | + # -- benchmarks --------------------------------------------------------- |
| 606 | + |
| 607 | + def _run(self, input_bytes): |
| 608 | + worker_main(io.BytesIO(input_bytes), io.BytesIO()) |
| 609 | + |
| 610 | + def time_small_batches_few_cols(self): |
| 611 | + """1k rows/batch, 5 cols, 1500 batches.""" |
| 612 | + self._run(self._small_few_input) |
| 613 | + |
| 614 | + def peakmem_small_batches_few_cols(self): |
| 615 | + """1k rows/batch, 5 cols, 1500 batches.""" |
| 616 | + self._run(self._small_few_input) |
| 617 | + |
| 618 | + def time_small_batches_many_cols(self): |
| 619 | + """1k rows/batch, 50 cols, 200 batches.""" |
| 620 | + self._run(self._small_many_input) |
| 621 | + |
| 622 | + def peakmem_small_batches_many_cols(self): |
| 623 | + """1k rows/batch, 50 cols, 200 batches.""" |
| 624 | + self._run(self._small_many_input) |
| 625 | + |
| 626 | + def time_large_batches_few_cols(self): |
| 627 | + """10k rows/batch, 5 cols, 3500 batches.""" |
| 628 | + self._run(self._large_few_input) |
| 629 | + |
| 630 | + def peakmem_large_batches_few_cols(self): |
| 631 | + """10k rows/batch, 5 cols, 3500 batches.""" |
| 632 | + self._run(self._large_few_input) |
| 633 | + |
| 634 | + def time_large_batches_many_cols(self): |
| 635 | + """10k rows/batch, 50 cols, 400 batches.""" |
| 636 | + self._run(self._large_many_input) |
| 637 | + |
| 638 | + def peakmem_large_batches_many_cols(self): |
| 639 | + """10k rows/batch, 50 cols, 400 batches.""" |
| 640 | + self._run(self._large_many_input) |
| 641 | + |
| 642 | + def time_compute(self): |
| 643 | + """10k rows/batch, 3 cols, 500 batches, arithmetic UDF.""" |
| 644 | + self._run(self._compute_input) |
| 645 | + |
| 646 | + def peakmem_compute(self): |
| 647 | + """10k rows/batch, 3 cols, 500 batches, arithmetic UDF.""" |
| 648 | + self._run(self._compute_input) |
| 649 | + |
| 650 | + def time_mixed_types(self): |
| 651 | + """Mixed column types, string UDF, 3 rows/batch, 1300 batches.""" |
| 652 | + self._run(self._mixed_input) |
| 653 | + |
| 654 | + def peakmem_mixed_types(self): |
| 655 | + """Mixed column types, string UDF, 3 rows/batch, 1300 batches.""" |
| 656 | + self._run(self._mixed_input) |