Commit 9befcc0

refactor: removal of processing pool for duckdb data contract

1 parent c2ad557 commit 9befcc0

File tree

6 files changed: +168 −205 lines changed

poetry.lock

Lines changed: 154 additions & 164 deletions
(Generated file; diff not rendered by default.)

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 2 additions & 9 deletions

@@ -3,9 +3,7 @@
 # pylint: disable=R0903
 import logging
 from collections.abc import Iterator
-from concurrent.futures import Future, ProcessPoolExecutor, as_completed
 from functools import partial
-from multiprocessing import cpu_count
 from typing import Any, Optional
 from uuid import uuid4

@@ -71,12 +69,10 @@ def __init__(
         connection: DuckDBPyConnection,
         logger: Optional[logging.Logger] = None,
         debug: bool = False,
-        executor: Optional[ProcessPoolExecutor] = None,
         **kwargs: Any,
     ):
         self.debug = debug
         self._connection = connection
-        self._executor = ProcessPoolExecutor(cpu_count() - 1) if not executor else executor
         """A bool indicating whether to enable debug logging."""

         super().__init__(logger, **kwargs)
@@ -167,11 +163,8 @@ def apply_data_contract(

         batches = pq.ParquetFile(entity_locations[entity_name]).iter_batches(10000)
         msg_count = 0
-        futures: list[Future] = [
-            self._executor.submit(row_validator_helper, batch) for batch in batches
-        ]
-        for future in as_completed(futures):
-            if msgs := future.result():
+        for batch in batches:
+            if msgs := row_validator_helper(arrow_batch=batch):
                 msg_writer.write_queue.put(msgs)
                 msg_count += len(msgs)
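The substantive change is the last hunk: batches from `ParquetFile.iter_batches(10000)` are no longer fanned out to a `ProcessPoolExecutor` and collected with `as_completed`; each batch is validated in-process, in order. A minimal sketch of the new control flow, assuming only what the diff shows (`row_validator_helper` accepts an `arrow_batch` and returns a list of feedback messages; the writer queue is a stand-in object with a `put` method):

    import pyarrow.parquet as pq

    def validate_batches(parquet_path: str, row_validator_helper, write_queue) -> int:
        """Validate each 10,000-row batch synchronously and queue any messages."""
        msg_count = 0
        for batch in pq.ParquetFile(parquet_path).iter_batches(10000):
            # row_validator_helper returns a (possibly empty) list of messages
            if msgs := row_validator_helper(arrow_batch=batch):
                write_queue.put(msgs)
                msg_count += len(msgs)
        return msg_count

The removed version submitted every batch to the pool up front (`[self._executor.submit(row_validator_helper, batch) for batch in batches]`), which materialised all batches before any result was consumed; the loop above processes one batch at a time.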

src/dve/pipeline/duckdb_pipeline.py

Lines changed: 1 addition & 3 deletions

@@ -1,7 +1,6 @@
 """DuckDB implementation for `Pipeline` object."""

 import logging
-from concurrent.futures import ProcessPoolExecutor
 from typing import Optional

 from duckdb import DuckDBPyConnection, DuckDBPyRelation
@@ -34,13 +33,12 @@ def __init__(
         reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
         job_run_id: Optional[int] = None,
         logger: Optional[logging.Logger] = None,
-        executor: Optional[ProcessPoolExecutor] = None,
     ):
         self._connection = connection
         super().__init__(
             processed_files_path,
             audit_tables,
-            DuckDBDataContract(connection=self._connection, executor=executor),
+            DuckDBDataContract(connection=self._connection),
             DuckDBStepImplementations.register_udfs(connection=self._connection),
             rules_path,
             submitted_files_path,
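For callers, the visible API change is that `DDBDVEPipeline` (and, through it, `DuckDBDataContract`) no longer accepts an `executor` argument. A hedged construction sketch, using only keyword arguments that appear in this commit's diffs; the paths are placeholders, the `DuckDBRefDataLoader` import path is assumed by analogy with the Spark loader imported in the test steps, and the real signature may require further arguments:

    import duckdb
    from dve.pipeline.duckdb_pipeline import DDBDVEPipeline
    # import path assumed by analogy with the Spark reference data loader
    from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader

    connection = duckdb.connect()  # in-memory database as a stand-in

    pipeline = DDBDVEPipeline(
        connection=connection,
        rules_path="rules/",                  # placeholder path
        submitted_files_path="submissions/",  # placeholder path
        reference_data_loader=DuckDBRefDataLoader,
        # executor=...  <- removed by this commit; passing it is now a TypeError
    )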

tests/features/environment.py

Lines changed: 0 additions & 4 deletions

@@ -1,5 +1,3 @@
-from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import cpu_count
 import shutil
 import tempfile
 from pathlib import Path
@@ -29,7 +27,6 @@ def before_all(context: Context):
     temp_dir = Path(context.dbfs_root.__enter__())
     dbfs_impl = DBFSFilesystemImplementation(temp_dir)
     add_implementation(dbfs_impl)
-    context.process_pool = ProcessPoolExecutor(cpu_count() - 1)


 def before_scenario(context: Context, scenario: Scenario):
@@ -81,4 +78,3 @@ def after_all(context: Context):

     context.connection.close()
     shutil.rmtree(context.ddb_db_file.parent)
-    context.process_pool.shutdown(wait=True, cancel_futures=True)

tests/features/steps/steps_pipeline.py

Lines changed: 5 additions & 11 deletions

@@ -6,12 +6,11 @@

 """
 # pylint: disable=no-name-in-module
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
-from itertools import chain
 import operator
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 from uuid import uuid4
 from behave import given, then, when  # type: ignore
 from behave.model import Row, Table
@@ -23,19 +22,16 @@

 import context_tools as ctxt
 import dve.parser.file_handling.service as fh
-from dve.pipeline.utils import SubmissionStatus, load_config

 import polars as pl
 from pyspark.sql import SparkSession
 from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager
 from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager
-from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
 from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
 from dve.pipeline.duckdb_pipeline import DDBDVEPipeline
 from dve.pipeline.spark_pipeline import SparkDVEPipeline

 from utilities import (
-    ERROR_DF_FIELDS,
     load_errors_from_service,
     get_test_file_path,
     SERVICE_TO_STORAGE_PATH_MAPPING,
@@ -74,8 +70,7 @@ def setup_duckdb_pipeline(
     connection: duckdb.DuckDBPyConnection,
     dataset_id: str,
     processing_path: Path,
-    schema_file_name: Optional[str] = None,
-    executor: Optional[ProcessPoolExecutor] = None
+    schema_file_name: Optional[str] = None
 ):

     schema_file_name = f"{dataset_id}.dischema.json" if not schema_file_name else schema_file_name
@@ -97,8 +92,7 @@ def setup_duckdb_pipeline(
         connection=connection,
         rules_path=rules_path,
         submitted_files_path=processing_path.as_posix(),
-        reference_data_loader=DuckDBRefDataLoader,
-        executor=executor
+        reference_data_loader=DuckDBRefDataLoader
     )


@@ -206,7 +200,7 @@ def add_pipeline_to_ctx(
     context: Context, implementation: str, schema_file_name: Optional[str] = None
 ):
     pipeline_map: Dict[str, Callable] = {
-        "duckdb": partial(setup_duckdb_pipeline, connection=context.connection, executor=context.process_pool),
+        "duckdb": partial(setup_duckdb_pipeline, connection=context.connection),
        "spark": partial(setup_spark_pipeline, spark=context.spark_session),
     }
     if not implementation in pipeline_map:
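The step definitions keep their dispatch pattern: `pipeline_map` binds each backend's resource into a `functools.partial`, and only the DuckDB entry loses its `executor` binding. A self-contained toy version of that pattern (the setup functions below are stand-ins, not the real test helpers):

    from functools import partial
    from typing import Callable, Dict

    def setup_duckdb(connection: str, dataset_id: str) -> str:
        # stand-in for setup_duckdb_pipeline
        return f"duckdb pipeline for {dataset_id} on {connection}"

    def setup_spark(spark: str, dataset_id: str) -> str:
        # stand-in for setup_spark_pipeline
        return f"spark pipeline for {dataset_id} via {spark}"

    pipeline_map: Dict[str, Callable[..., str]] = {
        # the backend resource is bound once; callers supply per-scenario arguments
        "duckdb": partial(setup_duckdb, connection="conn-handle"),
        "spark": partial(setup_spark, spark="spark-session"),
    }

    print(pipeline_map["duckdb"](dataset_id="demo"))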

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py

Lines changed: 6 additions & 14 deletions

@@ -1,6 +1,4 @@
-from concurrent.futures import ProcessPoolExecutor
 import json
-from multiprocessing import cpu_count
 from pathlib import Path
 from typing import Any, Dict, List, Tuple

@@ -32,13 +30,8 @@
     temp_xml_file,
 )

-@pytest.fixture(scope="module")
-def temp_process_pool_executor():
-    with ProcessPoolExecutor(cpu_count() - 1) as pool:
-        yield pool

-
-def test_duckdb_data_contract_csv(temp_csv_file, temp_process_pool_executor):
+def test_duckdb_data_contract_csv(temp_csv_file):
     uri, _, _, mdl = temp_csv_file
     connection = default_connection

@@ -97,7 +90,7 @@ def test_duckdb_data_contract_csv(temp_csv_file, temp_process_pool_executor):
     }
     entity_locations: Dict[str, URI] = {"test_ds": str(uri)}

-    data_contract: DuckDBDataContract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
+    data_contract: DuckDBDataContract = DuckDBDataContract(connection)
     entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta)
     rel: DuckDBPyRelation = entities.get("test_ds")
     assert dict(zip(rel.columns, rel.dtypes)) == {
@@ -108,7 +101,7 @@ def test_duckdb_data_contract_csv(temp_csv_file, temp_process_pool_executor):
     assert stage_successful


-def test_duckdb_data_contract_xml(temp_xml_file, temp_process_pool_executor):
+def test_duckdb_data_contract_xml(temp_xml_file):
     uri, header_model, header_data, class_model, class_data = temp_xml_file
     connection = default_connection
     contract_meta = json.dumps(
@@ -195,7 +188,7 @@ def test_duckdb_data_contract_xml(temp_xml_file, temp_process_pool_executor):
         reporting_fields={"test_header": ["school"], "test_class_info": ["year"]},
     )

-    data_contract: DuckDBDataContract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
+    data_contract: DuckDBDataContract = DuckDBDataContract(connection)
     entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta)
     header_rel: DuckDBPyRelation = entities.get("test_header")
     header_expected_schema: Dict[str, DuckDBPyType] = {
@@ -335,11 +328,10 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet):
     }

 def test_duckdb_data_contract_custom_error_details(nested_all_string_parquet_w_errors,
-                                                   nested_parquet_custom_dc_err_details,
-                                                   temp_process_pool_executor):
+                                                   nested_parquet_custom_dc_err_details):
     parquet_uri, contract_meta, _ = nested_all_string_parquet_w_errors
     connection = default_connection
-    data_contract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
+    data_contract = DuckDBDataContract(connection)

     entity = data_contract.read_parquet(path=parquet_uri)
     assert entity.count("*").fetchone()[0] == 2
