Skip to content

Commit c2ad557

Browse files
authored
Merge pull request #37 from NHSDigital/feature/ndit-655_refdata_arrow_integration
Add new duckdb data contract pyarrow batch approach and include arrow file loading in reference data
2 parents be0a630 + 85e9693 commit c2ad557

41 files changed

Lines changed: 1310 additions & 596 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ for entity in data_contract_config.schemas:
165165

166166
# Data contract step here
167167
data_contract = SparkDataContract(spark_session=spark)
168-
entities, validation_messages, success = data_contract.apply_data_contract(
169-
entities, data_contract_config
168+
entities, feedback_errors_uri, success = data_contract.apply_data_contract(
169+
entities, None, data_contract_config
170170
)
171171
```
172172

src/dve/common/__init__.py

Whitespace-only changes.

src/dve/common/error_utils.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Utilities to support reporting"""
2+
3+
import datetime as dt
4+
import json
5+
import logging
6+
from collections.abc import Iterable
7+
from itertools import chain
8+
from multiprocessing import Queue
9+
from threading import Thread
10+
from typing import Optional, Union
11+
12+
import dve.parser.file_handling as fh
13+
from dve.core_engine.exceptions import CriticalProcessingError
14+
from dve.core_engine.loggers import get_logger
15+
from dve.core_engine.message import UserMessage
16+
from dve.core_engine.type_hints import URI, DVEStageName, Messages
17+
18+
19+
def get_feedback_errors_uri(working_folder: URI, step_name: DVEStageName) -> URI:
    """Build the URI of the json lines file that collects every feedback
    error generated by the given DVE step."""
    filename = f"{step_name}_errors.jsonl"
    return fh.joinuri(working_folder, "errors", filename)
22+
23+
24+
def get_processing_errors_uri(working_folder: URI) -> URI:
    """Build the URI of the json lines file that collects all critical
    processing errors generated by a DVE run."""
    subfolder = "processing_errors"
    return fh.joinuri(working_folder, subfolder, "processing_errors.jsonl")
28+
29+
30+
def dump_feedback_errors(
    working_folder: URI,
    step_name: DVEStageName,
    messages: Messages,
    key_fields: Optional[dict[str, list[str]]] = None,
) -> URI:
    """Append captured feedback error messages to the step's errors jsonl file.

    Args:
        working_folder: Root working folder for this run; must be truthy.
        step_name: DVE stage the messages belong to (determines the file name).
        messages: Iterable of user messages to serialise.
        key_fields: Optional mapping of entity name -> primary key field names,
            used to render each message's "Key" field.

    Returns:
        The URI of the errors jsonl file (whether or not anything was written).

    Raises:
        AttributeError: If ``working_folder`` is not supplied.
    """
    if not working_folder:
        raise AttributeError("processed files path not passed")

    if not key_fields:
        key_fields = {}

    error_file = get_feedback_errors_uri(working_folder, step_name)
    processed = []

    for message in messages:
        # Prefer the original entity's keys; fall back to the (possibly
        # renamed) entity, then to no keys at all.
        if message.original_entity is not None:
            primary_keys = key_fields.get(message.original_entity, [])
        elif message.entity is not None:
            primary_keys = key_fields.get(message.entity, [])
        else:
            primary_keys = []

        error = message.to_dict(
            key_field=primary_keys,
            value_separator=" -- ",
            max_number_of_values=10,
            record_converter=None,
        )
        error["Key"] = conditional_cast(error["Key"], primary_keys, value_separator=" -- ")
        processed.append(error)

    # Only touch the file when there is something to write: appending a bare
    # "\n" for an empty batch would leave blank lines that break
    # load_feedback_messages (json.loads("") raises).
    if processed:
        with fh.open_stream(error_file, "a") as f:
            f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n")
    return error_file
66+
67+
68+
def dump_processing_errors(
    working_folder: URI, step_name: str, errors: list[CriticalProcessingError]
):
    """Append critical processing errors to the run's processing-errors jsonl file.

    Raises AttributeError when any of the required arguments is missing/empty.
    Returns the URI of the file written to.
    """
    if not working_folder:
        raise AttributeError("processed files path not passed")
    if not step_name:
        raise AttributeError("step name not passed")
    if not errors:
        raise AttributeError("errors list not passed")

    error_file: URI = get_processing_errors_uri(working_folder)

    # One jsonl record per error, tagged with the step that produced it.
    processed = [
        {
            "step_name": step_name,
            "error_location": "processing",
            "error_level": "integrity",
            "error_message": error.error_message,
            "error_traceback": error.messages,
        }
        for error in errors
    ]

    with fh.open_stream(error_file, "a") as f:
        f.write("\n".join(json.dumps(rec, default=str) for rec in processed) + "\n")

    return error_file
97+
98+
99+
def load_feedback_messages(feedback_messages_uri: URI) -> Iterable[UserMessage]:
    """Lazily load user messages from a jsonl file.

    Yields one ``UserMessage`` per non-blank line; yields nothing when the
    resource does not exist.
    """
    if not fh.get_resource_exists(feedback_messages_uri):
        return
    with fh.open_stream(feedback_messages_uri) as errs:
        # Iterate the stream directly instead of materialising readlines(),
        # and skip blank lines so stray trailing newlines do not raise a
        # JSON decode error.
        for line in errs:
            if line.strip():
                yield UserMessage(**json.loads(line))
105+
106+
107+
def load_all_error_messages(error_directory_uri: URI) -> Iterable[UserMessage]:
    """Lazily load user messages from every jsonl file under a directory prefix."""
    # Generator expression instead of a materialised list: files are opened
    # one at a time as the caller consumes the chain.
    return chain.from_iterable(
        load_feedback_messages(err_file)
        for err_file, _ in fh.iter_prefix(error_directory_uri)
        if err_file.endswith(".jsonl")
    )
116+
117+
118+
class BackgroundMessageWriter:
    """Controls batch writes to error jsonl files from a background thread.

    Producers put batches of messages on :attr:`write_queue`; a single writer
    thread drains the queue and appends each batch via
    ``dump_feedback_errors``, so callers are not blocked on file I/O.
    Intended to be used as a context manager: entering starts the writer
    thread, exiting sends the shutdown sentinel and joins it.
    """

    def __init__(
        self,
        working_directory: URI,
        dve_stage: DVEStageName,
        key_fields: Optional[dict[str, list[str]]] = None,
        logger: Optional[logging.Logger] = None,
    ):
        self._working_directory = working_directory
        self._dve_stage = dve_stage
        self._feedback_message_uri = get_feedback_errors_uri(
            self._working_directory, self._dve_stage
        )
        self._key_fields = key_fields
        self.logger = logger or get_logger(type(self).__name__)
        self._write_thread: Optional[Thread] = None
        self._queue: Queue = Queue()

    @property
    def write_queue(self) -> Queue:  # type: ignore
        """Queue for storing batches of messages to be written"""
        return self._queue

    @property
    def write_thread(self) -> Thread:  # type: ignore
        """Thread to write batches of messages to jsonl file (created lazily)"""
        if not self._write_thread:
            self._write_thread = Thread(target=self._write_process_wrapper)
        return self._write_thread

    def _write_process_wrapper(self):
        """Wrapper for dump feedback errors to run in background thread.

        Blocks on the queue; any falsy item (the ``None`` sentinel put by
        ``__exit__``) terminates the loop.
        """
        # writing thread will block if nothing in queue
        while True:
            if msgs := self.write_queue.get():
                dump_feedback_errors(
                    self._working_directory, self._dve_stage, msgs, self._key_fields
                )
            else:
                break

    def __enter__(self) -> "BackgroundMessageWriter":
        self.write_thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type:
            self.logger.exception(
                # Typo fix: "occured" -> "occurred".
                "Issue occurred during background write process:",
                exc_info=(exc_type, exc_value, traceback),
            )
        # None value in queue will trigger break in target
        self.write_queue.put(None)
        self.write_thread.join()
174+
175+
176+
def conditional_cast(value, primary_keys: list[str], value_separator: str) -> str:
    """Normalise a value coming back from the error list into a display string.

    - list: each element is cast recursively, paired positionally with its
      primary key name, and joined with *value_separator* as ``"pk: value"``;
      a falsy key name renders its segment as ``""``.  ``zip`` truncates to
      the shorter of the two sequences.
    - date/datetime: ISO-8601 via ``isoformat()`` (datetime subclasses date).
    - dict: rendered as an empty string.
    - anything else: ``str(value)``.

    Note: every branch returns ``str``, so the return annotation is narrowed
    from the previous ``Union[list[str], str]``.
    """
    if isinstance(value, list):
        # Recursively cast (avoids shadowing the builtin `id` for the pairs).
        casts = [conditional_cast(item, primary_keys, value_separator) for item in value]
        return value_separator.join(
            f"{key}: {cast}" if key else "" for key, cast in zip(primary_keys, casts)
        )
    if isinstance(value, dt.date):
        return value.isoformat()
    if isinstance(value, dict):
        return ""
    return str(value)

src/dve/core_engine/backends/base/backend.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@
1717
from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful
1818
from dve.core_engine.loggers import get_logger
1919
from dve.core_engine.models import SubmissionInfo
20-
from dve.core_engine.type_hints import (
21-
URI,
22-
EntityLocations,
23-
EntityName,
24-
EntityParquetLocations,
25-
Messages,
26-
)
20+
from dve.core_engine.type_hints import URI, EntityLocations, EntityName, EntityParquetLocations
21+
from dve.parser.file_handling.service import get_parent, joinuri
2722

2823

2924
class BaseBackend(Generic[EntityType], ABC):
@@ -148,65 +143,71 @@ def convert_entities_to_spark(
148143

149144
def apply(
150145
self,
146+
working_dir: URI,
151147
entity_locations: EntityLocations,
152148
contract_metadata: DataContractMetadata,
153149
rule_metadata: RuleMetadata,
154150
submission_info: Optional[SubmissionInfo] = None,
155-
) -> tuple[Entities, Messages, StageSuccessful]:
151+
) -> tuple[Entities, URI, StageSuccessful]:
156152
"""Apply the data contract and the rules, returning the entities and all
157153
generated messages.
158154
159155
"""
160156
reference_data = self.load_reference_data(
161157
rule_metadata.reference_data_config, submission_info
162158
)
163-
entities, messages, successful = self.contract.apply(entity_locations, contract_metadata)
159+
entities, dc_feedback_errors_uri, successful, processing_errors_uri = self.contract.apply(
160+
working_dir, entity_locations, contract_metadata
161+
)
164162
if not successful:
165-
return entities, messages, successful
163+
return entities, get_parent(processing_errors_uri), successful
166164

167165
for entity_name, entity in entities.items():
168166
entities[entity_name] = self.step_implementations.add_row_id(entity)
169167

170168
# TODO: Handle entity manager creation errors.
171169
entity_manager = EntityManager(entities, reference_data)
172170
# TODO: Add stage success to 'apply_rules'
173-
rule_messages = self.step_implementations.apply_rules(entity_manager, rule_metadata)
174-
messages.extend(rule_messages)
171+
# TODO: In case of large errors in business rules, write messages to jsonl file
172+
# TODO: and return uri to errors
173+
_ = self.step_implementations.apply_rules(working_dir, entity_manager, rule_metadata)
175174

176175
for entity_name, entity in entity_manager.entities.items():
177176
entity_manager.entities[entity_name] = self.step_implementations.drop_row_id(entity)
178177

179-
return entity_manager.entities, messages, True
178+
return entity_manager.entities, get_parent(dc_feedback_errors_uri), True
180179

181180
def process(
182181
self,
182+
working_dir: URI,
183183
entity_locations: EntityLocations,
184184
contract_metadata: DataContractMetadata,
185185
rule_metadata: RuleMetadata,
186-
cache_prefix: URI,
187186
submission_info: Optional[SubmissionInfo] = None,
188-
) -> tuple[MutableMapping[EntityName, URI], Messages]:
187+
) -> tuple[MutableMapping[EntityName, URI], URI]:
189188
"""Apply the data contract and the rules, write the entities out to parquet
190189
and returning the entity locations and all generated messages.
191190
192191
"""
193-
entities, messages, successful = self.apply(
194-
entity_locations, contract_metadata, rule_metadata, submission_info
192+
entities, feedback_errors_uri, successful = self.apply(
193+
working_dir, entity_locations, contract_metadata, rule_metadata, submission_info
195194
)
196195
if successful:
197-
parquet_locations = self.write_entities_to_parquet(entities, cache_prefix)
196+
parquet_locations = self.write_entities_to_parquet(
197+
entities, joinuri(working_dir, "outputs")
198+
)
198199
else:
199200
parquet_locations = {}
200-
return parquet_locations, messages
201+
return parquet_locations, get_parent(feedback_errors_uri)
201202

202203
def process_legacy(
203204
self,
205+
working_dir: URI,
204206
entity_locations: EntityLocations,
205207
contract_metadata: DataContractMetadata,
206208
rule_metadata: RuleMetadata,
207-
cache_prefix: URI,
208209
submission_info: Optional[SubmissionInfo] = None,
209-
) -> tuple[MutableMapping[EntityName, DataFrame], Messages]:
210+
) -> tuple[MutableMapping[EntityName, DataFrame], URI]:
210211
"""Apply the data contract and the rules, create Spark `DataFrame`s from the
211212
entities and return the Spark entities and all generated messages.
212213
@@ -221,17 +222,19 @@ def process_legacy(
221222
category=DeprecationWarning,
222223
)
223224

224-
entities, messages, successful = self.apply(
225-
entity_locations, contract_metadata, rule_metadata, submission_info
225+
entities, errors_uri, successful = self.apply(
226+
working_dir, entity_locations, contract_metadata, rule_metadata, submission_info
226227
)
227228

228229
if not successful:
229-
return {}, messages
230+
return {}, errors_uri
230231

231232
if self.__entity_type__ == DataFrame:
232-
return entities, messages # type: ignore
233+
return entities, errors_uri # type: ignore
233234

234235
return (
235-
self.convert_entities_to_spark(entities, cache_prefix, _emit_deprecation_warning=False),
236-
messages,
236+
self.convert_entities_to_spark(
237+
entities, joinuri(working_dir, "outputs"), _emit_deprecation_warning=False
238+
),
239+
errors_uri,
237240
)

0 commit comments

Comments
 (0)