diff --git a/.github/workflows/ci_energyml_utils_pull_request.yml b/.github/workflows/ci_energyml_utils_pull_request.yml index 8903539..50380a7 100644 --- a/.github/workflows/ci_energyml_utils_pull_request.yml +++ b/.github/workflows/ci_energyml_utils_pull_request.yml @@ -3,7 +3,7 @@ ## SPDX-License-Identifier: Apache-2.0 ## --- -name: Publish (pypiTest) +name: Test/Build/Publish (pypiTest) defaults: run: @@ -18,8 +18,31 @@ on: types: [published] jobs: + test: + name: Run tests + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install poetry + uses: ./.github/actions/prepare-poetry + with: + python-version: "3.10" + + - name: Install dependencies + run: | + poetry install + + - name: Run pytest + run: | + poetry run pytest -v --tb=short + build: name: Build distribution + needs: [test] runs-on: ubuntu-latest steps: - name: Checkout code diff --git a/energyml-utils/.flake8 b/energyml-utils/.flake8 index 07de32c..4830dae 100644 --- a/energyml-utils/.flake8 +++ b/energyml-utils/.flake8 @@ -1,6 +1,6 @@ [flake8] # Ignore specific error codes (comma-separated list) -ignore = E501, E722, W503, F403, E203, E202 +ignore = E501, E722, W503, F403, E203, E202, E402 # Max line length (default is 79, can be changed) max-line-length = 120 diff --git a/energyml-utils/.gitignore b/energyml-utils/.gitignore index 38a850f..f672e3c 100644 --- a/energyml-utils/.gitignore +++ b/energyml-utils/.gitignore @@ -44,6 +44,7 @@ sample/ gen*/ manip* *.epc +*.h5 *.off *.obj *.log @@ -54,8 +55,16 @@ manip* *.xml *.json +docs/*.md + +# DATA +*.obj +*.geojson +*.vtk +*.stl # WIP src/energyml/utils/wip* -scripts \ No newline at end of file +scripts +rc/camunda \ No newline at end of file diff --git a/energyml-utils/example/epc_stream_keep_open_example.py b/energyml-utils/example/epc_stream_keep_open_example.py new file mode 100644 index 0000000..ea9d9cc --- /dev/null +++ b/energyml-utils/example/epc_stream_keep_open_example.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +# Copyright (c) 2023-2024 Geosiris. +# SPDX-License-Identifier: Apache-2.0 +""" +Example demonstrating the keep_open feature of EpcStreamReader. + +This example shows how using keep_open=True improves performance when +performing multiple operations on an EPC file by keeping the ZIP file +open instead of reopening it for each operation. 
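+
+A minimal usage sketch (the EPC path is illustrative; any readable EPC file works)::
+
+    with EpcStreamReader("some_file.epc", keep_open=True, cache_size=5) as reader:
+        for meta in reader.list_object_metadata():
+            obj = reader.get_object_by_identifier(meta.identifier)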
+""" + +import time +import sys +from pathlib import Path + +# Add src directory to path +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + +from energyml.utils.epc_stream import EpcStreamReader + + +def benchmark_without_keep_open(epc_path: str, num_operations: int = 10): + """Benchmark reading objects without keep_open.""" + print(f"\nBenchmark WITHOUT keep_open ({num_operations} operations):") + print("=" * 60) + + start = time.time() + + # Create reader without keep_open + with EpcStreamReader(epc_path, keep_open=False, cache_size=5) as reader: + metadata_list = reader.list_object_metadata() + + if not metadata_list: + print(" No objects in EPC file") + return 0 + + # Perform multiple read operations + for i in range(min(num_operations, len(metadata_list))): + meta = metadata_list[i % len(metadata_list)] + if meta.identifier: + _ = reader.get_object_by_identifier(meta.identifier) + if i == 0: + print(f" First object: {meta.object_type}") + + elapsed = time.time() - start + print(f" Time: {elapsed:.4f}s") + print(f" Avg per operation: {elapsed / num_operations:.4f}s") + + return elapsed + + +def benchmark_with_keep_open(epc_path: str, num_operations: int = 10): + """Benchmark reading objects with keep_open.""" + print(f"\nBenchmark WITH keep_open ({num_operations} operations):") + print("=" * 60) + + start = time.time() + + # Create reader with keep_open + with EpcStreamReader(epc_path, keep_open=True, cache_size=5) as reader: + metadata_list = reader.list_object_metadata() + + if not metadata_list: + print(" No objects in EPC file") + return 0 + + # Perform multiple read operations + for i in range(min(num_operations, len(metadata_list))): + meta = metadata_list[i % len(metadata_list)] + if meta.identifier: + _ = reader.get_object_by_identifier(meta.identifier) + if i == 0: + print(f" First object: {meta.object_type}") + + elapsed = time.time() - start + print(f" Time: {elapsed:.4f}s") + print(f" Avg per operation: {elapsed / num_operations:.4f}s") + + return elapsed + + +def demonstrate_file_modification_with_keep_open(epc_path: str): + """Demonstrate that modifications work correctly with keep_open.""" + print("\nDemonstrating file modifications with keep_open:") + print("=" * 60) + + with EpcStreamReader(epc_path, keep_open=True) as reader: + metadata_list = reader.list_object_metadata() + original_count = len(metadata_list) + print(f" Original object count: {original_count}") + + if metadata_list: + # Get first object + first_obj = reader.get_object_by_identifier(metadata_list[0].identifier) + print(f" Retrieved object: {metadata_list[0].object_type}") + + # Update the object (re-add it) + identifier = reader.update_object(first_obj) + print(f" Updated object: {identifier}") + + # Verify we can still read it after update + updated_obj = reader.get_object_by_identifier(identifier) + assert updated_obj is not None, "Failed to read object after update" + print(" ✓ Object successfully read after update") + + # Verify object count is the same + new_metadata_list = reader.list_object_metadata() + new_count = len(new_metadata_list) + print(f" New object count: {new_count}") + + if new_count == original_count: + print(" ✓ Object count unchanged (correct)") + else: + print(f" ✗ Object count changed: {original_count} -> {new_count}") + + +def demonstrate_proper_cleanup(): + """Demonstrate that persistent ZIP file is properly closed.""" + print("\nDemonstrating proper cleanup:") + print("=" * 60) + + temp_path = "temp_test.epc" + + try: + # Create a 
temporary EPC file + reader = EpcStreamReader(temp_path, keep_open=True) + print(" Created EpcStreamReader with keep_open=True") + + # Manually close + reader.close() + print(" ✓ Manually closed reader") + + # Create another reader and let it go out of scope + reader2 = EpcStreamReader(temp_path, keep_open=True) + print(" Created second EpcStreamReader") + del reader2 + print(" ✓ Reader deleted (automatic cleanup via __del__)") + + # Create reader in context manager + with EpcStreamReader(temp_path, keep_open=True) as _: + print(" Created third EpcStreamReader in context manager") + print(" ✓ Context manager exited (automatic cleanup)") + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + +def main(): + """Run all examples.""" + print("EpcStreamReader keep_open Feature Demonstration") + print("=" * 60) + + # You'll need to provide a valid EPC file path + epc_path = "wip/epc_test.epc" + + if not Path(epc_path).exists(): + print(f"\nError: EPC file not found: {epc_path}") + print("Please provide a valid EPC file path in the script.") + print("\nRunning cleanup demonstration only:") + demonstrate_proper_cleanup() + return + + try: + # Run benchmarks + num_ops = 20 + + time_without = benchmark_without_keep_open(epc_path, num_ops) + time_with = benchmark_with_keep_open(epc_path, num_ops) + + # Show comparison + print("\n" + "=" * 60) + print("Performance Comparison:") + print("=" * 60) + if time_with > 0 and time_without > 0: + speedup = time_without / time_with + improvement = ((time_without - time_with) / time_without) * 100 + print(f" Speedup: {speedup:.2f}x") + print(f" Improvement: {improvement:.1f}%") + + if speedup > 1.1: + print("\n ✓ keep_open=True significantly improves performance!") + elif speedup > 1.0: + print("\n ✓ keep_open=True slightly improves performance") + else: + print("\n Note: For this workload, the difference is minimal") + print(" (cache effects or small file)") + + # Demonstrate modifications + demonstrate_file_modification_with_keep_open(epc_path) + + # Demonstrate cleanup + demonstrate_proper_cleanup() + + print("\n" + "=" * 60) + print("All demonstrations completed successfully!") + print("=" * 60) + + except Exception as e: + print(f"\nError: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/energyml-utils/example/main.py b/energyml-utils/example/main.py index 6301e7c..4313ed5 100644 --- a/energyml-utils/example/main.py +++ b/energyml-utils/example/main.py @@ -1,14 +1,27 @@ # Copyright (c) 2023-2024 Geosiris. 
# SPDX-License-Identifier: Apache-2.0 import sys +import logging from pathlib import Path import re from dataclasses import fields +from energyml.utils.constants import ( + RGX_CONTENT_TYPE, + EpcExportVersion, + date_to_epoch, + epoch, + epoch_to_date, + gen_uuid, + get_domain_version_from_content_or_qualified_type, + parse_content_or_qualified_type, + parse_content_type, +) + src_path = Path(__file__).parent.parent / "src" sys.path.insert(0, str(src_path)) -from energyml.eml.v2_3.commonv2 import * +from energyml.eml.v2_3.commonv2 import Citation, DataObjectReference, ExistenceKind, Activity from energyml.eml.v2_3.commonv2 import AbstractObject from energyml.resqml.v2_0_1.resqmlv2 import DoubleHdf5Array from energyml.resqml.v2_0_1.resqmlv2 import TriangulatedSetRepresentation as Tr20 @@ -22,17 +35,70 @@ # from src.energyml.utils.data.hdf import * from energyml.utils.data.helper import get_projected_uom, is_z_reversed -from energyml.utils.epc import * -from energyml.utils.introspection import * -from energyml.utils.manager import * -from energyml.utils.serialization import * +from energyml.utils.epc import ( + Epc, + EPCRelsRelationshipType, + as_dor, + create_energyml_object, + create_external_part_reference, + gen_energyml_object_path, + get_reverse_dor_list, +) +from energyml.utils.introspection import ( + class_match_rgx, + copy_attributes, + get_class_attributes, + get_class_fields, + get_class_from_content_type, + get_class_from_name, + get_class_from_qualified_type, + get_class_methods, + get_content_type_from_class, + get_obj_pkg_pkgv_type_uuid_version, + get_obj_uri, + get_object_attribute, + get_obj_uuid, + get_object_attribute_rgx, + get_qualified_type_from_class, + is_abstract, + is_primitive, + random_value_from_class, + search_attribute_matching_name, + search_attribute_matching_name_with_path, + search_attribute_matching_type, + search_attribute_matching_type_with_path, +) +from energyml.utils.manager import ( + # create_energyml_object, + # create_external_part_reference, + dict_energyml_modules, + get_class_pkg, + get_class_pkg_version, + get_classes_matching_name, + get_sub_classes, + list_energyml_modules, +) +from energyml.utils.serialization import ( + read_energyml_xml_file, + read_energyml_xml_str, + serialize_json, + JSON_VERSION, + serialize_xml, +) from energyml.utils.validation import ( patterns_validation, dor_validation, validate_epc, correct_dor, ) -from energyml.utils.xml import * +from energyml.utils.xml import ( + find_schema_version_in_element, + get_class_name_from_xml, + get_root_namespace, + get_root_type, + get_tree, + get_xml_encoding, +) from energyml.utils.data.datasets_io import HDF5FileReader, get_path_in_external_with_path fi_cit = Citation( diff --git a/energyml-utils/example/main_data.py b/energyml-utils/example/main_data.py index a05cd20..52ff8ee 100644 --- a/energyml-utils/example/main_data.py +++ b/energyml-utils/example/main_data.py @@ -1,6 +1,7 @@ # Copyright (c) 2023-2024 Geosiris. 
# SPDX-License-Identifier: Apache-2.0 - +import logging +from io import BytesIO from energyml.eml.v2_3.commonv2 import ( JaggedArray, AbstractValueArray, @@ -8,16 +9,27 @@ StringXmlArray, IntegerXmlArray, ) +from energyml.utils.data.export import export_obj from src.energyml.utils.data.helper import ( get_array_reader_function, + read_array, +) +from src.energyml.utils.data.mesh import ( + GeoJsonGeometryType, + MeshFileFormat, + _create_shape, + _write_geojson_shape, + export_multiple_data, + export_off, + read_mesh_object, ) -from src.energyml.utils.data.mesh import * -from src.energyml.utils.data.mesh import _create_shape, _write_geojson_shape from src.energyml.utils.epc import gen_energyml_object_path from src.energyml.utils.introspection import ( + get_object_attribute, is_abstract, get_obj_uuid, + search_attribute_matching_name_with_path, ) from src.energyml.utils.manager import get_sub_classes from src.energyml.utils.serialization import ( @@ -28,11 +40,17 @@ ) from src.energyml.utils.validation import validate_epc from src.energyml.utils.xml import get_tree -from utils.data.datasets_io import ( +from src.energyml.utils.data.datasets_io import ( HDF5FileReader, get_path_in_external_with_path, get_external_file_path_from_external_path, ) +from energyml.utils.epc import Epc +from src.energyml.utils.data.mesh import ( + read_polyline_representation, + read_point_representation, + read_grid2d_representation, +) logger = logging.getLogger(__name__) @@ -607,7 +625,7 @@ def test_simple_geojson(): ), ) - print(f"\n+++++++++++++++++++++++++\n") + print("\n+++++++++++++++++++++++++\n") def test_simple_geojson_io(): diff --git a/energyml-utils/example/main_datasets.py b/energyml-utils/example/main_datasets.py index edc1278..234ed43 100644 --- a/energyml-utils/example/main_datasets.py +++ b/energyml-utils/example/main_datasets.py @@ -1,15 +1,15 @@ # Copyright (c) 2023-2024 Geosiris. # SPDX-License-Identifier: Apache-2.0 -from src.energyml.utils.data.datasets_io import ( +from energyml.utils.data.datasets_io import ( ParquetFileReader, ParquetFileWriter, CSVFileReader, CSVFileWriter, read_dataset, ) -from utils.data.helper import read_array -from utils.introspection import search_attribute_matching_name_with_path -from utils.serialization import read_energyml_xml_file +from energyml.utils.data.helper import read_array +from energyml.utils.introspection import search_attribute_matching_name_with_path +from energyml.utils.serialization import read_energyml_xml_file def local_parquet(): diff --git a/energyml-utils/example/main_hdf.py b/energyml-utils/example/main_hdf.py deleted file mode 100644 index ac23ed4..0000000 --- a/energyml-utils/example/main_hdf.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2023-2024 Geosiris. 
-# SPDX-License-Identifier: Apache-2.0 -import sys -from pathlib import Path - -# Add src directory to path -src_path = Path(__file__).parent.parent / "src" -sys.path.insert(0, str(src_path)) - -from energyml.utils.data.datasets_io import get_path_in_external_with_path -from energyml.utils.introspection import get_obj_uri - - -if __name__ == "__main__": - from energyml.utils.epc import Epc - - # Create an EPC file - epc = Epc.read_file("wip/BRGM_AVRE_all_march_25.epc") - - print("\n".join(map(lambda o: str(get_obj_uri(o)), epc.energyml_objects))) - - print(epc.get_h5_file_paths("eml:///resqml22.PolylineSetRepresentation(e75db94d-a251-4f31-8a24-23b9573fbf39)")) - - print( - get_path_in_external_with_path( - epc.get_object_by_identifier( - "eml:///resqml22.PolylineSetRepresentation(e75db94d-a251-4f31-8a24-23b9573fbf39)" - ) - ) - ) - - print( - epc.read_h5_dataset( - "eml:///resqml22.PolylineSetRepresentation(e75db94d-a251-4f31-8a24-23b9573fbf39)", - "/RESQML/e75db94d-a251-4f31-8a24-23b9573fbf39/points_patch0", - ) - ) diff --git a/energyml-utils/example/main_stream.py b/energyml-utils/example/main_stream.py index b1a712a..87f529a 100644 --- a/energyml-utils/example/main_stream.py +++ b/energyml-utils/example/main_stream.py @@ -24,12 +24,13 @@ from energyml.utils.serialization import serialize_json +from energyml.resqml.v2_2.resqmlv2 import TriangulatedSetRepresentation, ContactElement +from energyml.eml.v2_3.commonv2 import DataObjectReference + + def test_epc_stream_main(): logging.basicConfig(level=logging.DEBUG) - from energyml.resqml.v2_2.resqmlv2 import TriangulatedSetRepresentation, ContactElement - from energyml.eml.v2_3.commonv2 import DataObjectReference - # Use the test EPC file test_epc = "wip/my_stream_file.epc" @@ -115,9 +116,6 @@ def test_epc_stream_main(): def test_epc_im_main(): logging.basicConfig(level=logging.DEBUG) - from energyml.resqml.v2_2.resqmlv2 import TriangulatedSetRepresentation, ContactElement - from energyml.eml.v2_3.commonv2 import DataObjectReference - # Use the test EPC file test_epc = "wip/my_stream_file.epc" diff --git a/energyml-utils/example/main_test_3D.py b/energyml-utils/example/main_test_3D.py new file mode 100644 index 0000000..0657bdf --- /dev/null +++ b/energyml-utils/example/main_test_3D.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023-2024 Geosiris. 
+# SPDX-License-Identifier: Apache-2.0 +import os +import re +import datetime +from pathlib import Path +import traceback +from typing import Optional + +from energyml.utils.data.export import export_obj, export_stl, export_vtk +from energyml.utils.data.mesh import read_mesh_object +from energyml.utils.epc_stream import EpcStreamReader +from energyml.utils.epc import Epc + +from energyml.utils.exception import NotSupportedError + + +def export_all_representation(epc_path: str, output_dir: str, regex_type_filter: Optional[str] = None): + + storage = EpcStreamReader(epc_path, keep_open=True) + + dt = datetime.datetime.now().strftime("%Hh%M_%d-%m-%Y") + not_supported_types = set() + for mdata in storage.list_objects(): + if "Representation" in mdata.object_type and ( + regex_type_filter is None + or len(regex_type_filter) == 0 + or re.search(regex_type_filter, mdata.object_type, flags=re.IGNORECASE) + ): + logging.info(f"Exporting representation: {mdata.object_type} ({mdata.uuid})") + energyml_obj = storage.get_object_by_uuid(mdata.uuid)[0] + try: + mesh_list = read_mesh_object( + energyml_object=energyml_obj, + workspace=storage, + use_crs_displacement=True, + ) + + os.makedirs(output_dir, exist_ok=True) + + path = Path(output_dir) / f"{dt}-{mdata.object_type}{mdata.uuid}_mesh.obj" + with path.open("wb") as f: + export_obj( + mesh_list=mesh_list, + out=f, + ) + export_stl_path = path.with_suffix(".stl") + with export_stl_path.open("wb") as stl_f: + export_stl( + mesh_list=mesh_list, + out=stl_f, + ) + export_vtk_path = path.with_suffix(".vtk") + with export_vtk_path.open("wb") as vtk_f: + export_vtk( + mesh_list=mesh_list, + out=vtk_f, + ) + + logging.info(f" ✓ Exported to {path.name}") + except NotSupportedError: + # print(f" ✗ Not supported: {e}") + not_supported_types.add(mdata.object_type) + except Exception: + traceback.print_exc() + + logging.info("Export completed.") + if not_supported_types: + logging.info("Not supported representation types encountered:") + for t in not_supported_types: + logging.info(f" - {t}") + + +def export_all_representation_in_memory(epc_path: str, output_dir: str, regex_type_filter: Optional[str] = None): + + storage = Epc.read_file(epc_path) + if storage is None: + logging.error(f"Failed to read EPC file: {epc_path}") + return + + dt = datetime.datetime.now().strftime("%Hh%M_%d-%m-%Y") + not_supported_types = set() + for mdata in storage.list_objects(): + if "Representation" in mdata.object_type and ( + regex_type_filter is None + or len(regex_type_filter) == 0 + or re.search(regex_type_filter, mdata.object_type, flags=re.IGNORECASE) + ): + logging.info(f"Exporting representation: {mdata.object_type} ({mdata.uuid})") + energyml_obj = storage.get_object_by_uuid(mdata.uuid)[0] + try: + mesh_list = read_mesh_object( + energyml_object=energyml_obj, + workspace=storage, + use_crs_displacement=True, + ) + + os.makedirs(output_dir, exist_ok=True) + + path = Path(output_dir) / f"{dt}-{mdata.object_type}{mdata.uuid}_mesh.obj" + with path.open("wb") as f: + export_obj( + mesh_list=mesh_list, + out=f, + ) + export_stl_path = path.with_suffix(".stl") + with export_stl_path.open("wb") as stl_f: + export_stl( + mesh_list=mesh_list, + out=stl_f, + ) + export_vtk_path = path.with_suffix(".vtk") + with export_vtk_path.open("wb") as vtk_f: + export_vtk( + mesh_list=mesh_list, + out=vtk_f, + ) + + logging.info(f" ✓ Exported to {path.name}") + except NotSupportedError: + # print(f" ✗ Not supported: {e}") + not_supported_types.add(mdata.object_type) + except Exception: + 
traceback.print_exc() + + logging.info("Export completed.") + if not_supported_types: + logging.info("Not supported representation types encountered:") + for t in not_supported_types: + logging.info(f" - {t}") + + +# $env:PYTHONPATH="$(pwd)\src"; poetry run python example/main_test_3D.py +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.DEBUG) + # epc_file = "rc/epc/testingPackageCpp.epc" + epc_file = "rc/epc/output-val.epc" + # epc_file = "rc/epc/Volve_Horizons_and_Faults_Depth_originEQN.epc" + output_directory = Path("exported_meshes") / Path(epc_file).name.replace(".epc", "_3D_export") + # export_all_representation(epc_file, output_directory) + # export_all_representation(epc_file, output_directory, regex_type_filter="Wellbore") + # export_all_representation(epc_file, str(output_directory), regex_type_filter="") + export_all_representation_in_memory(epc_file, str(output_directory), regex_type_filter="") diff --git a/energyml-utils/example/tools.py b/energyml-utils/example/tools.py index 3c889ba..20dfe69 100644 --- a/energyml-utils/example/tools.py +++ b/energyml-utils/example/tools.py @@ -291,7 +291,7 @@ def generate_data(): "-ff", type=str, default="json", - help=f"Type of the output files (one of : ['json', 'xml']). Default is 'json'", + help="Type of the output files (one of : ['json', 'xml']). Default is 'json'", ) args = parser.parse_args() @@ -413,7 +413,7 @@ def xml_to_json(): def json_to_xml(): parser = argparse.ArgumentParser() parser.add_argument("--file", "-f", type=str, help="Input File") - parser.add_argument("--out", "-o", type=str, default=None, help=f"Output file") + parser.add_argument("--out", "-o", type=str, default=None, help="Output file") args = parser.parse_args() @@ -436,7 +436,7 @@ def json_to_xml(): def json_to_epc(): parser = argparse.ArgumentParser() parser.add_argument("--file", "-f", type=str, help="Input File") - parser.add_argument("--out", "-o", type=str, default=None, help=f"Output EPC file") + parser.add_argument("--out", "-o", type=str, default=None, help="Output EPC file") args = parser.parse_args() diff --git a/energyml-utils/pyproject.toml b/energyml-utils/pyproject.toml index a3ff9a8..4ce977f 100644 --- a/energyml-utils/pyproject.toml +++ b/energyml-utils/pyproject.toml @@ -48,6 +48,10 @@ include = [ [tool.pytest.ini_options] pythonpath = [ "src" ] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", +] +addopts = "-m 'not slow'" testpaths = [ "tests" ] python_files = [ "test_*.py", "*_test.py" ] python_classes = [ "Test*" ] @@ -76,6 +80,7 @@ black = "^22.3.0" pylint = "^2.7.2" click = ">=8.1.3, <=8.1.3" # upper version than 8.0.2 fail with black pdoc3 = "^0.10.0" +pydantic = { version = "^2.0", optional = true } energyml-common2-0 = "^1.12.0" energyml-common2-1 = "^1.12.0" energyml-common2-2 = "^1.12.0" diff --git a/energyml-utils/src/energyml/utils/constants.py b/energyml-utils/src/energyml/utils/constants.py index f2e13d8..5735660 100644 --- a/energyml-utils/src/energyml/utils/constants.py +++ b/energyml-utils/src/energyml/utils/constants.py @@ -427,7 +427,10 @@ def epoch(time_zone=datetime.timezone.utc) -> int: def date_to_epoch(date: str) -> int: """Convert energyml date string to epoch timestamp""" try: - return int(datetime.datetime.fromisoformat(date).timestamp()) + # Python 3.10 doesn't support 'Z' suffix in fromisoformat() + # Replace 'Z' with '+00:00' for compatibility + date_normalized = date.replace("Z", "+00:00") if date.endswith("Z") else date + return 
int(datetime.datetime.fromisoformat(date_normalized).timestamp()) except (ValueError, TypeError): raise ValueError(f"Invalid date format: {date}") diff --git a/energyml-utils/src/energyml/utils/data/datasets_io.py b/energyml-utils/src/energyml/utils/data/datasets_io.py index 3325eeb..d899015 100644 --- a/energyml-utils/src/energyml/utils/data/datasets_io.py +++ b/energyml-utils/src/energyml/utils/data/datasets_io.py @@ -54,61 +54,98 @@ # HDF5 if __H5PY_MODULE_EXISTS__: - def h5_list_datasets(h5_file_path: Union[BytesIO, str]) -> List[str]: + def h5_list_datasets(h5_file_path: Union[BytesIO, str, "h5py.File"]) -> List[str]: """ List all datasets in an HDF5 file. - :param h5_file_path: Path to the HDF5 file + :param h5_file_path: Path to the HDF5 file, BytesIO object, or an already opened h5py.File :return: List of dataset names in the HDF5 file """ res = [] - with h5py.File(h5_file_path, "r") as f: # type: ignore - # Function to print the names of all datasets + + # Check if it's already an opened h5py.File + if isinstance(h5_file_path, h5py.File): # type: ignore + def list_datasets(name, obj): - if isinstance(obj, h5py.Dataset): # Check if the object is a dataset # type: ignore + if isinstance(obj, h5py.Dataset): # type: ignore res.append(name) - # Visit all items in the HDF5 file and apply the list function - f.visititems(list_datasets) + h5_file_path.visititems(list_datasets) + else: + with h5py.File(h5_file_path, "r") as f: # type: ignore + # Function to print the names of all datasets + def list_datasets(name, obj): + if isinstance(obj, h5py.Dataset): # Check if the object is a dataset # type: ignore + res.append(name) + + # Visit all items in the HDF5 file and apply the list function + f.visititems(list_datasets) return res @dataclass class HDF5FileReader(DatasetReader): # noqa: F401 - def read_array(self, source: Union[BytesIO, str], path_in_external_file: str) -> Optional[np.ndarray]: - with h5py.File(source, "r") as f: # type: ignore - d_group = f[path_in_external_file] + def read_array( + self, source: Union[BytesIO, str, "h5py.File"], path_in_external_file: str + ) -> Optional[np.ndarray]: + # Check if it's already an opened h5py.File + if isinstance(source, h5py.File): # type: ignore + d_group = source[path_in_external_file] return d_group[()] # type: ignore - - def get_array_dimension(self, source: Union[BytesIO, str], path_in_external_file: str) -> Optional[List[int]]: - with h5py.File(source, "r") as f: # type: ignore - return list(f[path_in_external_file].shape) + else: + with h5py.File(source, "r") as f: # type: ignore + d_group = f[path_in_external_file] + return d_group[()] # type: ignore + + def get_array_dimension( + self, source: Union[BytesIO, str, "h5py.File"], path_in_external_file: str + ) -> Optional[List[int]]: + # Check if it's already an opened h5py.File + if isinstance(source, h5py.File): # type: ignore + return list(source[path_in_external_file].shape) + else: + with h5py.File(source, "r") as f: # type: ignore + return list(f[path_in_external_file].shape) def extract_h5_datasets( self, - input_h5: Union[BytesIO, str], - output_h5: Union[BytesIO, str], + input_h5: Union[BytesIO, str, "h5py.File"], + output_h5: Union[BytesIO, str, "h5py.File"], h5_datasets_paths: List[str], ) -> None: """ Copy all dataset from :param input_h5 matching with paths in :param h5_datasets_paths into the :param output - :param input_h5: - :param output_h5: + :param input_h5: Path to HDF5 file, BytesIO, or already opened h5py.File + :param output_h5: Path to HDF5 file, BytesIO, or 
already opened h5py.File :param h5_datasets_paths: :return: """ if h5_datasets_paths is None: h5_datasets_paths = h5_list_datasets(input_h5) if len(h5_datasets_paths) > 0: - with h5py.File(output_h5, "a") as f_dest: # type: ignore - with h5py.File(input_h5, "r") as f_src: # type: ignore + # Handle output file + should_close_dest = not isinstance(output_h5, h5py.File) # type: ignore + f_dest = output_h5 if isinstance(output_h5, h5py.File) else h5py.File(output_h5, "a") # type: ignore + + try: + # Handle input file + should_close_src = not isinstance(input_h5, h5py.File) # type: ignore + f_src = input_h5 if isinstance(input_h5, h5py.File) else h5py.File(input_h5, "r") # type: ignore + + try: for dataset in h5_datasets_paths: f_dest.create_dataset(dataset, data=f_src[dataset]) + finally: + if should_close_src: + f_src.close() + finally: + if should_close_dest: + f_dest.close() @dataclass class HDF5FileWriter: def write_array( self, - target: Union[str, BytesIO, bytes], + target: Union[str, BytesIO, bytes, "h5py.File"], array: Union[list, np.ndarray], path_in_external_file: str, dtype: Optional[np.dtype] = None, @@ -119,32 +156,53 @@ def write_array( if dtype is not None and not isinstance(dtype, np.dtype): dtype = np.dtype(dtype) - with h5py.File(target, "a") as f: # type: ignore - # print(array.dtype, h5py.string_dtype(), array.dtype == 'O') - # print("\t", dtype or (h5py.string_dtype() if array.dtype == '0' else array.dtype)) + # Check if it's already an opened h5py.File + if isinstance(target, h5py.File): # type: ignore if isinstance(array, np.ndarray) and array.dtype == "O": array = np.asarray([s.encode() if isinstance(s, str) else s for s in array]) np.void(array) - dset = f.create_dataset(path_in_external_file, array.shape, dtype or array.dtype) + dset = target.create_dataset(path_in_external_file, array.shape, dtype or array.dtype) dset[()] = array + else: + with h5py.File(target, "a") as f: # type: ignore + # print(array.dtype, h5py.string_dtype(), array.dtype == 'O') + # print("\t", dtype or (h5py.string_dtype() if array.dtype == '0' else array.dtype)) + if isinstance(array, np.ndarray) and array.dtype == "O": + array = np.asarray([s.encode() if isinstance(s, str) else s for s in array]) + np.void(array) + dset = f.create_dataset(path_in_external_file, array.shape, dtype or array.dtype) + dset[()] = array else: class HDF5FileReader: - def read_array(self, source: Union[BytesIO, str], path_in_external_file: str) -> Optional[np.ndarray]: + def read_array(self, source: Union[BytesIO, str, Any], path_in_external_file: str) -> Optional[np.ndarray]: raise MissingExtraInstallation(extra_name="hdf5") - def get_array_dimension(self, source: Union[BytesIO, str], path_in_external_file: str) -> Optional[np.ndarray]: + def get_array_dimension( + self, source: Union[BytesIO, str, Any], path_in_external_file: str + ) -> Optional[np.ndarray]: raise MissingExtraInstallation(extra_name="hdf5") def extract_h5_datasets( self, - input_h5: Union[BytesIO, str], - output_h5: Union[BytesIO, str], + input_h5: Union[BytesIO, str, Any], + output_h5: Union[BytesIO, str, Any], h5_datasets_paths: List[str], ) -> None: raise MissingExtraInstallation(extra_name="hdf5") + class HDF5FileWriter: + + def write_array( + self, + target: Union[str, BytesIO, bytes, Any], + array: Union[list, np.ndarray], + path_in_external_file: str, + dtype: Optional[np.dtype] = None, + ): + raise MissingExtraInstallation(extra_name="hdf5") + # APACHE PARQUET if __PARQUET_MODULE_EXISTS__: diff --git 
a/energyml-utils/src/energyml/utils/data/export.py b/energyml-utils/src/energyml/utils/data/export.py new file mode 100644 index 0000000..48d9681 --- /dev/null +++ b/energyml-utils/src/energyml/utils/data/export.py @@ -0,0 +1,489 @@ +# Copyright (c) 2023-2024 Geosiris. +# SPDX-License-Identifier: Apache-2.0 +""" +Module for exporting mesh data to various file formats. +Supports OBJ, GeoJSON, VTK, and STL formats. +""" + +import json +import struct +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, BinaryIO, List, Optional, TextIO, Union + +import numpy as np + +if TYPE_CHECKING: + from .mesh import AbstractMesh + + +class ExportFormat(Enum): + """Supported mesh export formats.""" + + OBJ = "obj" + GEOJSON = "geojson" + VTK = "vtk" + STL = "stl" + + @classmethod + def from_extension(cls, extension: str) -> "ExportFormat": + """Get format from file extension.""" + ext = extension.lower().lstrip(".") + for fmt in cls: + if fmt.value == ext: + return fmt + raise ValueError(f"Unsupported file extension: {extension}") + + @classmethod + def all_extensions(cls) -> List[str]: + """Get all supported file extensions.""" + return [fmt.value for fmt in cls] + + +class ExportOptions: + """Base class for export options.""" + + pass + + +class STLExportOptions(ExportOptions): + """Options for STL export.""" + + def __init__(self, binary: bool = True, ascii_precision: int = 6): + """ + Initialize STL export options. + + :param binary: If True, export as binary STL; if False, export as ASCII STL + :param ascii_precision: Number of decimal places for ASCII format + """ + self.binary = binary + self.ascii_precision = ascii_precision + + +class VTKExportOptions(ExportOptions): + """Options for VTK export.""" + + def __init__(self, binary: bool = False, dataset_name: str = "mesh"): + """ + Initialize VTK export options. + + :param binary: If True, export as binary VTK; if False, export as ASCII VTK + :param dataset_name: Name of the dataset in VTK file + """ + self.binary = binary + self.dataset_name = dataset_name + + +class GeoJSONExportOptions(ExportOptions): + """Options for GeoJSON export.""" + + def __init__(self, indent: Optional[int] = 2, properties: Optional[dict] = None): + """ + Initialize GeoJSON export options. + + :param indent: JSON indentation level (None for compact) + :param properties: Additional properties to include in features + """ + self.indent = indent + self.properties = properties or {} + + +def export_obj(mesh_list: List["AbstractMesh"], out: BinaryIO, obj_name: Optional[str] = None) -> None: + """ + Export mesh data to Wavefront OBJ format. 
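+
+    A minimal usage sketch, assuming ``mesh_list`` was obtained from ``read_mesh_object``
+    (the output file name is illustrative)::
+
+        with open("surface.obj", "wb") as f:
+            export_obj(mesh_list, f, obj_name="my_surface")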
+ + :param mesh_list: List of AbstractMesh objects to export + :param out: Binary output stream + :param obj_name: Optional object name for the OBJ file + """ + # Lazy import to avoid circular dependency + from .mesh import PolylineSetMesh + + # Write header + out.write(b"# Generated by energyml-utils a Geosiris python module\n\n") + + # Write object name if provided + if obj_name is not None: + out.write(f"o {obj_name}\n\n".encode("utf-8")) + + point_offset = 0 + + for mesh in mesh_list: + # Write group name using mesh identifier or uuid + mesh_id = getattr(mesh, "identifier", None) or getattr(mesh, "uuid", "mesh") + out.write(f"g {mesh_id}\n\n".encode("utf-8")) + + # Write vertices + for point in mesh.point_list: + if len(point) > 0: + out.write(f"v {' '.join(map(str, point))}\n".encode("utf-8")) + + # Write faces or lines depending on mesh type + indices = mesh.get_indices() + elt_letter = "l" if isinstance(mesh, PolylineSetMesh) else "f" + + for face_or_line in indices: + if len(face_or_line) > 1: + # OBJ indices are 1-based + indices_str = " ".join(str(idx + point_offset + 1) for idx in face_or_line) + out.write(f"{elt_letter} {indices_str}\n".encode("utf-8")) + + point_offset += len(mesh.point_list) + + +def export_geojson( + mesh_list: List["AbstractMesh"], out: TextIO, options: Optional[GeoJSONExportOptions] = None +) -> None: + """ + Export mesh data to GeoJSON format. + + :param mesh_list: List of AbstractMesh objects to export + :param out: Text output stream + :param options: GeoJSON export options + """ + # Lazy import to avoid circular dependency + from .mesh import PolylineSetMesh, SurfaceMesh + + if options is None: + options = GeoJSONExportOptions() + + features = [] + + for mesh_idx, mesh in enumerate(mesh_list): + indices = mesh.get_indices() + + if isinstance(mesh, PolylineSetMesh): + # Export as LineString features + for line_idx, line_indices in enumerate(indices): + if len(line_indices) < 2: + continue + coordinates = [list(mesh.point_list[idx]) for idx in line_indices] + feature = { + "type": "Feature", + "geometry": {"type": "LineString", "coordinates": coordinates}, + "properties": {"mesh_index": mesh_idx, "line_index": line_idx, **options.properties}, + } + features.append(feature) + + elif isinstance(mesh, SurfaceMesh): + # Export as Polygon features + for face_idx, face_indices in enumerate(indices): + if len(face_indices) < 3: + continue + # GeoJSON Polygon requires closed ring (first point == last point) + coordinates = [list(mesh.point_list[idx]) for idx in face_indices] + coordinates.append(coordinates[0]) # Close the ring + + feature = { + "type": "Feature", + "geometry": {"type": "Polygon", "coordinates": [coordinates]}, + "properties": {"mesh_index": mesh_idx, "face_index": face_idx, **options.properties}, + } + features.append(feature) + + geojson = {"type": "FeatureCollection", "features": features} + + json.dump(geojson, out, indent=options.indent) + + +def export_vtk(mesh_list: List["AbstractMesh"], out: BinaryIO, options: Optional[VTKExportOptions] = None) -> None: + """ + Export mesh data to VTK legacy format. 
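+
+    A minimal usage sketch (file and dataset names are illustrative)::
+
+        with open("mesh.vtk", "wb") as f:
+            export_vtk(mesh_list, f, VTKExportOptions(dataset_name="my_mesh"))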
+ + :param mesh_list: List of AbstractMesh objects to export + :param out: Binary output stream + :param options: VTK export options + """ + # Lazy import to avoid circular dependency + from .mesh import PolylineSetMesh, SurfaceMesh + + if options is None: + options = VTKExportOptions() + + # Combine all meshes + all_points = [] + all_polygons = [] + all_lines = [] + vertex_offset = 0 + + for mesh in mesh_list: + all_points.extend(mesh.point_list) + indices = mesh.get_indices() + + if isinstance(mesh, SurfaceMesh): + # Adjust face indices + for face in indices: + adjusted_face = [idx + vertex_offset for idx in face] + all_polygons.append(adjusted_face) + elif isinstance(mesh, PolylineSetMesh): + # Adjust line indices + for line in indices: + adjusted_line = [idx + vertex_offset for idx in line] + all_lines.append(adjusted_line) + + vertex_offset += len(mesh.point_list) + + # Write VTK header + out.write(b"# vtk DataFile Version 3.0\n") + out.write(f"{options.dataset_name}\n".encode("utf-8")) + out.write(b"ASCII\n") + out.write(b"DATASET POLYDATA\n") + + # Write points + out.write(f"POINTS {len(all_points)} float\n".encode("utf-8")) + for point in all_points: + out.write(f"{point[0]} {point[1]} {point[2]}\n".encode("utf-8")) + + # Write polygons + if all_polygons: + total_poly_size = sum(len(poly) + 1 for poly in all_polygons) + out.write(f"POLYGONS {len(all_polygons)} {total_poly_size}\n".encode("utf-8")) + for poly in all_polygons: + out.write(f"{len(poly)} {' '.join(str(idx) for idx in poly)}\n".encode("utf-8")) + + # Write lines + if all_lines: + total_line_size = sum(len(line) + 1 for line in all_lines) + out.write(f"LINES {len(all_lines)} {total_line_size}\n".encode("utf-8")) + for line in all_lines: + out.write(f"{len(line)} {' '.join(str(idx) for idx in line)}\n".encode("utf-8")) + + +def export_stl(mesh_list: List["AbstractMesh"], out: BinaryIO, options: Optional[STLExportOptions] = None) -> None: + """ + Export mesh data to STL format (binary or ASCII). + + Note: STL format only supports triangles. Only triangular faces will be exported. 
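+
+    A minimal usage sketch of the ASCII variant (the file name is illustrative)::
+
+        with open("mesh.stl", "wb") as f:
+            export_stl(mesh_list, f, STLExportOptions(binary=False, ascii_precision=6))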
+
+    :param mesh_list: List of AbstractMesh objects to export
+    :param out: Binary output stream
+    :param options: STL export options
+    """
+    # Lazy import to avoid circular dependency
+    from .mesh import SurfaceMesh
+
+    if options is None:
+        options = STLExportOptions(binary=True)
+
+    # Collect all triangles (only from SurfaceMesh with triangular faces)
+    all_triangles = []
+    for mesh in mesh_list:
+        if isinstance(mesh, SurfaceMesh):
+            indices = mesh.get_indices()
+            for face in indices:
+                # Only export triangular faces
+                if len(face) == 3:
+                    p0 = np.array(mesh.point_list[face[0]])
+                    p1 = np.array(mesh.point_list[face[1]])
+                    p2 = np.array(mesh.point_list[face[2]])
+                    all_triangles.append((p0, p1, p2))
+
+    if options.binary:
+        _export_stl_binary(all_triangles, out)
+    else:
+        _export_stl_ascii(all_triangles, out, options.ascii_precision)
+
+
+def _export_stl_binary(triangles: List[tuple], out: BinaryIO) -> None:
+    """Export STL in binary format."""
+    # Write 80-byte header
+    header = b"Binary STL file generated by energyml-utils" + b"\0" * (80 - 44)
+    out.write(header)
+
+    # Write number of triangles (little-endian uint32)
+    out.write(struct.pack("<I", len(triangles)))
+
+    # Write each triangle record: normal, 3 vertices, attribute byte count
+    for p0, p1, p2 in triangles:
+        # Calculate normal vector
+        v1 = p1 - p0
+        v2 = p2 - p0
+        normal = np.cross(v1, v2)
+        norm = np.linalg.norm(normal)
+        if norm > 0:
+            normal = normal / norm
+        else:
+            normal = np.array([0.0, 0.0, 0.0])
+
+        # Write normal
+        out.write(struct.pack("<fff", *normal))
+
+        # Write the three vertices
+        for point in [p0, p1, p2]:
+            out.write(struct.pack("<fff", *point))
+
+        # Write attribute byte count (unused, set to 0)
+        out.write(struct.pack("<H", 0))
+
+
+def _export_stl_ascii(triangles: List[tuple], out: BinaryIO, precision: int = 6) -> None:
+    """Export STL in ASCII format."""
+    out.write(b"solid mesh\n")
+
+    for p0, p1, p2 in triangles:
+        # Calculate normal vector
+        v1 = p1 - p0
+        v2 = p2 - p0
+        normal = np.cross(v1, v2)
+        norm = np.linalg.norm(normal)
+        if norm > 0:
+            normal = normal / norm
+        else:
+            normal = np.array([0.0, 0.0, 0.0])
+
+        # Write facet
+        line = f"  facet normal {normal[0]:.{precision}e} {normal[1]:.{precision}e} {normal[2]:.{precision}e}\n"
+        out.write(line.encode("utf-8"))
+        out.write(b"    outer loop\n")
+
+        for point in [p0, p1, p2]:
+            line = f"      vertex {point[0]:.{precision}e} {point[1]:.{precision}e} {point[2]:.{precision}e}\n"
+            out.write(line.encode("utf-8"))
+
+        out.write(b"    endloop\n")
+        out.write(b"  endfacet\n")
+
+    out.write(b"endsolid mesh\n")
+
+
+def export_mesh(
+    mesh_list: List["AbstractMesh"],
+    output_path: Union[str, Path],
+    format: Optional[ExportFormat] = None,
+    options: Optional[ExportOptions] = None,
+) -> None:
+    """
+    Export mesh data to a file in the specified format.
+
+    :param mesh_list: List of Mesh objects to export
+    :param output_path: Output file path
+    :param format: Export format (auto-detected from extension if None)
+    :param options: Format-specific export options
+    """
+    path = Path(output_path)
+
+    # Auto-detect format from extension if not specified
+    if format is None:
+        format = ExportFormat.from_extension(path.suffix)
+
+    # Determine if file should be opened in binary or text mode
+    binary_formats = {ExportFormat.OBJ, ExportFormat.STL, ExportFormat.VTK}
+    text_formats = {ExportFormat.GEOJSON}
+
+    if format in binary_formats:
+        with path.open("wb") as f:
+            if format == ExportFormat.OBJ:
+                export_obj(mesh_list, f)
+            elif format == ExportFormat.STL:
+                export_stl(mesh_list, f, options)
+            elif format == ExportFormat.VTK:
+                export_vtk(mesh_list, f, options)
+    elif format in text_formats:
+        with path.open("w", encoding="utf-8") as f:
+            if format == ExportFormat.GEOJSON:
+                export_geojson(mesh_list, f, options)
+    else:
+        raise ValueError(f"Unsupported format: {format}")
+
+
+# UI Helper Functions
+
+
+def supported_formats() -> List[str]:
+    """
+    Get list of supported export formats.
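+
+    A sketch of how a UI could combine this with :func:`export_mesh`, which
+    auto-detects the format from the file extension (file names are illustrative)::
+
+        for ext in supported_formats():
+            export_mesh(mesh_list, f"preview.{ext}")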
+ + :return: List of format names (e.g., ['obj', 'geojson', 'vtk', 'stl']) + """ + return ExportFormat.all_extensions() + + +def format_description(format: Union[str, ExportFormat]) -> str: + """ + Get human-readable description of a format. + + :param format: Format name or ExportFormat enum + :return: Description string + """ + if isinstance(format, str): + format = ExportFormat.from_extension(format) + + descriptions = { + ExportFormat.OBJ: "Wavefront OBJ - 3D geometry format (triangles and lines)", + ExportFormat.GEOJSON: "GeoJSON - Geographic data format (lines and polygons)", + ExportFormat.VTK: "VTK Legacy - Visualization Toolkit format", + ExportFormat.STL: "STL - Stereolithography format (triangles only)", + } + return descriptions.get(format, "Unknown format") + + +def format_filter_string(format: Union[str, ExportFormat]) -> str: + """ + Get file filter string for UI dialogs (Qt, tkinter, etc.). + + :param format: Format name or ExportFormat enum + :return: Filter string (e.g., "OBJ Files (*.obj)") + """ + if isinstance(format, str): + format = ExportFormat.from_extension(format) + + filters = { + ExportFormat.OBJ: "OBJ Files (*.obj)", + ExportFormat.GEOJSON: "GeoJSON Files (*.geojson)", + ExportFormat.VTK: "VTK Files (*.vtk)", + ExportFormat.STL: "STL Files (*.stl)", + } + return filters.get(format, "All Files (*.*)") + + +def all_formats_filter_string() -> str: + """ + Get file filter string for all supported formats. + Useful for Qt QFileDialog or similar UI components. + + :return: Filter string with all formats + """ + filters = [format_filter_string(fmt) for fmt in ExportFormat] + return ";;".join(filters) + + +def get_format_options_class(format: Union[str, ExportFormat]) -> Optional[type]: + """ + Get the options class for a specific format. + + :param format: Format name or ExportFormat enum + :return: Options class or None if no options available + """ + if isinstance(format, str): + format = ExportFormat.from_extension(format) + + options_map = { + ExportFormat.STL: STLExportOptions, + ExportFormat.VTK: VTKExportOptions, + ExportFormat.GEOJSON: GeoJSONExportOptions, + } + return options_map.get(format) + + +def supports_lines(format: Union[str, ExportFormat]) -> bool: + """ + Check if format supports line primitives. + + :param format: Format name or ExportFormat enum + :return: True if format supports lines + """ + if isinstance(format, str): + format = ExportFormat.from_extension(format) + + return format in {ExportFormat.OBJ, ExportFormat.GEOJSON, ExportFormat.VTK} + + +def supports_triangles(format: Union[str, ExportFormat]) -> bool: + """ + Check if format supports triangle primitives. 
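+
+    Together with :func:`supports_lines`, this lets a UI grey out formats that cannot
+    hold a given mesh type; for example ``supports_lines("stl")`` is False while
+    ``supports_triangles("stl")`` is True.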
+ + :param format: Format name or ExportFormat enum + :return: True if format supports triangles + """ + # All formats support triangles + return True diff --git a/energyml-utils/src/energyml/utils/data/helper.py b/energyml-utils/src/energyml/utils/data/helper.py index febba46..9ebde1d 100644 --- a/energyml-utils/src/energyml/utils/data/helper.py +++ b/energyml-utils/src/energyml/utils/data/helper.py @@ -5,13 +5,14 @@ import sys from typing import Any, Optional, Callable, List, Union +from energyml.utils.storage_interface import EnergymlStorageInterface import numpy as np from .datasets_io import read_external_dataset_array from ..constants import flatten_concatenation -from ..epc import get_obj_identifier from ..exception import ObjectNotFoundNotError from ..introspection import ( + get_obj_uri, snake_case, get_object_attribute_no_verif, search_attribute_matching_name_with_path, @@ -21,7 +22,7 @@ get_object_attribute, get_object_attribute_rgx, ) -from ..workspace import EnergymlWorkspace + from .datasets_io import get_path_in_external_with_path _ARRAY_NAMES_ = [ @@ -86,20 +87,29 @@ def is_z_reversed(crs: Optional[Any]) -> bool: """ reverse_z_values = False if crs is not None: - # resqml 201 - zincreasing_downward = search_attribute_matching_name(crs, "ZIncreasingDownward") - if len(zincreasing_downward) > 0: - reverse_z_values = zincreasing_downward[0] - - # resqml >= 22 - vert_axis = search_attribute_matching_name(crs, "VerticalAxis.Direction") - if len(vert_axis) > 0: - vert_axis_str = str(vert_axis[0]) - if "." in vert_axis_str: - vert_axis_str = vert_axis_str.split(".")[-1] - - reverse_z_values = vert_axis_str.lower() == "down" - + if "VerticalCrs" in type(crs).__name__: + vert_axis = search_attribute_matching_name(crs, "Direction") + if len(vert_axis) > 0: + vert_axis_str = str(vert_axis[0]) + if "." in vert_axis_str: + vert_axis_str = vert_axis_str.split(".")[-1] + + reverse_z_values = vert_axis_str.lower() == "down" + else: + # resqml 201 + zincreasing_downward = search_attribute_matching_name(crs, "ZIncreasingDownward") + if len(zincreasing_downward) > 0: + reverse_z_values = zincreasing_downward[0] + + # resqml >= 22 + vert_axis = search_attribute_matching_name(crs, "VerticalAxis.Direction") + if len(vert_axis) > 0: + vert_axis_str = str(vert_axis[0]) + if "." 
in vert_axis_str: + vert_axis_str = vert_axis_str.split(".")[-1] + + reverse_z_values = vert_axis_str.lower() == "down" + logging.debug(f"is_z_reversed: {reverse_z_values}") return reverse_z_values @@ -114,7 +124,7 @@ def get_vertical_epsg_code(crs_object: Any): return vertical_epsg_code -def get_projected_epsg_code(crs_object: Any, workspace: Optional[EnergymlWorkspace] = None): +def get_projected_epsg_code(crs_object: Any, workspace: Optional[EnergymlStorageInterface] = None): if crs_object is not None: # LocalDepth3dCRS projected_epsg_code = get_object_attribute_rgx(crs_object, "ProjectedCrs.EpsgCode") if projected_epsg_code is None: # LocalEngineering2DCrs @@ -130,7 +140,7 @@ def get_projected_epsg_code(crs_object: Any, workspace: Optional[EnergymlWorkspa return None -def get_projected_uom(crs_object: Any, workspace: Optional[EnergymlWorkspace] = None): +def get_projected_uom(crs_object: Any, workspace: Optional[EnergymlStorageInterface] = None): if crs_object is not None: projected_epsg_uom = get_object_attribute_rgx(crs_object, "ProjectedUom") if projected_epsg_uom is None: @@ -144,7 +154,7 @@ def get_projected_uom(crs_object: Any, workspace: Optional[EnergymlWorkspace] = return None -def get_crs_origin_offset(crs_obj: Any) -> List[float]: +def get_crs_origin_offset(crs_obj: Any) -> List[float | int]: """ Return a list [X,Y,Z] corresponding to the crs Offset [XOffset/OriginProjectedCoordinate1, ... ] depending on the crs energyml version. @@ -163,12 +173,12 @@ def get_crs_origin_offset(crs_obj: Any) -> List[float]: if tmp_offset_z is None: tmp_offset_z = get_object_attribute_rgx(crs_obj, "OriginProjectedCoordinate3") - crs_point_offset = [0, 0, 0] + crs_point_offset = [0.0, 0.0, 0.0] try: crs_point_offset = [ - float(tmp_offset_x) if tmp_offset_x is not None else 0, - float(tmp_offset_y) if tmp_offset_y is not None else 0, - float(tmp_offset_z) if tmp_offset_z is not None else 0, + float(tmp_offset_x) if tmp_offset_x is not None else 0.0, + float(tmp_offset_y) if tmp_offset_y is not None else 0.0, + float(tmp_offset_z) if tmp_offset_z is not None else 0.0, ] except Exception as e: logging.info(f"ERR reading crs offset {e}") @@ -183,30 +193,66 @@ def prod_n_tab(val: Union[float, int, str], tab: List[Union[float, int, str]]): :param tab: :return: """ - return list(map(lambda x: x * val, tab)) + if val is None: + return [None] * len(tab) + logging.debug(f"Multiplying list by {val}: {tab}") + # Convert to numpy array for vectorized operations, handling None values + arr = np.array(tab, dtype=object) + logging.debug(f"arr: {arr}") + # Create mask for non-None values + mask = arr != None # noqa: E711 + # Create result array filled with None + result = np.full(len(tab), None, dtype=object) + logging.debug(f"result before multiplication: {result}") + # Multiply only non-None values + result[mask] = arr[mask].astype(float) * val + logging.debug(f"result after multiplication: {result}") + return result.tolist() def sum_lists(l1: List, l2: List): """ - Sums 2 lists values. + Sums 2 lists values, preserving None values. 
Example: [1,1,1] and [2,2,3,6] gives : [3,3,4,6] + [1,None,3] and [2,2,3] gives : [3,None,6] :param l1: :param l2: :return: """ - return [l1[i] + l2[i] for i in range(min(len(l1), len(l2)))] + max(l1, l2, key=len)[ - min(len(l1), len(l2)) : # noqa: E203 - ] + min_len = min(len(l1), len(l2)) + + # Convert to numpy arrays for vectorized operations + arr1 = np.array(l1[:min_len], dtype=object) + arr2 = np.array(l2[:min_len], dtype=object) + + # Create result array + result = np.full(min_len, None, dtype=object) + + # Find indices where both values are not None + mask = (arr1 != None) & (arr2 != None) # noqa: E711 + + # Sum only where both are not None + if np.any(mask): + result[mask] = arr1[mask].astype(float) + arr2[mask].astype(float) + + # Convert back to list and append remaining elements from longer list + result_list = result.tolist() + if len(l1) > min_len: + result_list.extend(l1[min_len:]) + elif len(l2) > min_len: + result_list.extend(l2[min_len:]) + + return result_list def get_crs_obj( context_obj: Any, path_in_root: Optional[str] = None, root_obj: Optional[Any] = None, - workspace: Optional[EnergymlWorkspace] = None, + workspace: Optional[EnergymlStorageInterface] = None, ) -> Optional[Any]: """ Search for the CRS object related to :param:`context_obj` into the :param:`workspace` @@ -222,12 +268,12 @@ def get_crs_obj( crs_list = search_attribute_matching_name(context_obj, r"\.*Crs", search_in_sub_obj=True, deep_search=False) if crs_list is not None and len(crs_list) > 0: # logging.debug(crs_list[0]) - crs = workspace.get_object_by_identifier(get_obj_identifier(crs_list[0])) + crs = workspace.get_object(get_obj_uri(crs_list[0])) if crs is None: crs = workspace.get_object_by_uuid(get_obj_uuid(crs_list[0])) if crs is None: logging.error(f"CRS {crs_list[0]} not found (or not read correctly)") - raise ObjectNotFoundNotError(get_obj_identifier(crs_list[0])) + raise ObjectNotFoundNotError(get_obj_uri(crs_list[0])) if crs is not None: return crs @@ -293,9 +339,9 @@ def read_external_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, -) -> Union[List[Any], np.ndarray]: + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, +) -> Optional[Union[List[Any], np.ndarray]]: """ Read an external array (BooleanExternalArray, BooleanHdf5Array, DoubleHdf5Array, IntegerHdf5Array, StringExternalArray ...) :param energyml_array: @@ -333,10 +379,11 @@ def read_external_array( ) if sub_indices is not None and len(sub_indices) > 0: - res = [] - for idx in sub_indices: - res.append(array[idx]) - array = res + if isinstance(array, np.ndarray): + array = array[sub_indices] + elif isinstance(array, list): + # Fallback for non-numpy arrays + array = [array[idx] for idx in sub_indices] return array @@ -357,9 +404,9 @@ def read_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, -) -> List[Any]: + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, +) -> Union[List[Any], np.ndarray]: """ Read an array and return a list. The array is read depending on its type. see. 
:py:func:`energyml.utils.data.helper.get_supported_array` :param energyml_array: @@ -393,8 +440,8 @@ def read_constant_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: Optional[List[int]] = None, + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[Any]: """ Read a constant array ( BooleanConstantArray, DoubleConstantArray, FloatingPointConstantArray, IntegerConstantArray ...) @@ -423,9 +470,9 @@ def read_xml_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, -) -> List[Any]: + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, +) -> Union[List[Any], np.ndarray]: """ Read a xml array ( BooleanXmlArray, FloatingPointXmlArray, IntegerXmlArray, StringXmlArray ...) :param energyml_array: @@ -439,10 +486,11 @@ def read_xml_array( # count = get_object_attribute_no_verif(energyml_array, "count_per_value") if sub_indices is not None and len(sub_indices) > 0: - res = [] - for idx in sub_indices: - res.append(values[idx]) - values = res + if isinstance(values, np.ndarray): + values = values[sub_indices] + elif isinstance(values, list): + # Use list comprehension for efficiency + values = [values[idx] for idx in sub_indices] return values @@ -450,8 +498,8 @@ def read_jagged_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[Any]: """ Read a jagged array @@ -465,27 +513,23 @@ def read_jagged_array( elements = read_array( energyml_array=get_object_attribute_no_verif(energyml_array, "elements"), root_obj=root_obj, - path_in_root=path_in_root + ".elements", + path_in_root=(path_in_root or "") + ".elements", workspace=workspace, ) cumulative_length = read_array( energyml_array=read_array(get_object_attribute_no_verif(energyml_array, "cumulative_length")), root_obj=root_obj, - path_in_root=path_in_root + ".cumulative_length", + path_in_root=(path_in_root or "") + ".cumulative_length", workspace=workspace, ) - array = [] - previous = 0 - for cl in cumulative_length: - array.append(elements[previous:cl]) - previous = cl + # Use list comprehension for better performance + array = [ + elements[cumulative_length[i - 1] if i > 0 else 0 : cumulative_length[i]] for i in range(len(cumulative_length)) + ] if sub_indices is not None and len(sub_indices) > 0: - res = [] - for idx in sub_indices: - res.append(array[idx]) - array = res + array = [array[idx] for idx in sub_indices] return array @@ -493,8 +537,8 @@ def read_int_double_lattice_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ): """ Read DoubleLatticeArray or IntegerLatticeArray. 
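+
+    For a 1D lattice (a single constant-array offset), the values are generated as
+    ``start_value + i * offset_value`` for ``i`` in ``range(count)``; e.g. a start value
+    of 100, an offset value of 10 and a count of 3 give ``[100, 110, 120]``.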
@@ -505,27 +549,33 @@ def read_int_double_lattice_array( :param sub_indices: :return: """ - # start_value = get_object_attribute_no_verif(energyml_array, "start_value") + start_value = get_object_attribute_no_verif(energyml_array, "start_value") offset = get_object_attribute_no_verif(energyml_array, "offset") - # result = [] + result = [] + + if len(offset) == 1: + # 1D lattice array: offset is a single DoubleConstantArray or IntegerConstantArray + offset_obj = offset[0] + + # Get the offset value and count from the ConstantArray + offset_value = get_object_attribute_no_verif(offset_obj, "value") + count = get_object_attribute_no_verif(offset_obj, "count") - # if len(offset) == 1: - # pass - # elif len(offset) == 2: - # pass - # else: - raise Exception(f"{type(energyml_array)} read with an offset of length {len(offset)} is not supported") + # Generate the 1D array: start_value + i * offset_value for i in range(count) + result = [start_value + i * offset_value for i in range(count)] + else: + raise Exception(f"{type(energyml_array)} read with an offset of length {len(offset)} is not supported") - # return result + return result def read_point3d_zvalue_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ): """ Read a Point3D2ValueArray @@ -540,7 +590,7 @@ def read_point3d_zvalue_array( sup_geom_array = read_array( energyml_array=supporting_geometry, root_obj=root_obj, - path_in_root=path_in_root + ".SupportingGeometry", + path_in_root=(path_in_root or "") + ".SupportingGeometry", workspace=workspace, sub_indices=sub_indices, ) @@ -550,21 +600,32 @@ def read_point3d_zvalue_array( read_array( energyml_array=zvalues, root_obj=root_obj, - path_in_root=path_in_root + ".ZValues", + path_in_root=(path_in_root or "") + ".ZValues", workspace=workspace, sub_indices=sub_indices, ) ) - count = 0 + # Use NumPy for vectorized operation if possible + error_logged = False - for i in range(len(sup_geom_array)): - try: - sup_geom_array[i][2] = zvalues_array[i] - except Exception as e: - if count == 0: - logging.error(e, f": {i} is out of bound of {len(zvalues_array)}") - count = count + 1 + if isinstance(sup_geom_array, np.ndarray) and isinstance(zvalues_array, np.ndarray): + # Vectorized assignment for NumPy arrays + min_len = min(len(sup_geom_array), len(zvalues_array)) + if min_len < len(sup_geom_array): + logging.warning( + f"Z-values array ({len(zvalues_array)}) is shorter than geometry array ({len(sup_geom_array)}), only updating first {min_len} values" + ) + sup_geom_array[:min_len, 2] = zvalues_array[:min_len] + else: + # Fallback for list-based arrays + for i in range(len(sup_geom_array)): + try: + sup_geom_array[i][2] = zvalues_array[i] + except (IndexError, TypeError) as e: + if not error_logged: + logging.error(f"{type(e).__name__}: index {i} is out of bound of {len(zvalues_array)}") + error_logged = True return sup_geom_array @@ -573,8 +634,8 @@ def read_point3d_from_representation_lattice_array( energyml_array: Any, root_obj: Optional[Any] = None, path_in_root: Optional[str] = None, - workspace: Optional[EnergymlWorkspace] = None, - sub_indices: List[int] = None, + workspace: Optional[EnergymlStorageInterface] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ): """ Read a Point3DFromRepresentationLatticeArray. 
@@ -588,11 +649,9 @@ def read_point3d_from_representation_lattice_array(
     :param sub_indices:
     :return:
     """
-    supporting_rep_identifier = get_obj_identifier(
-        get_object_attribute_no_verif(energyml_array, "supporting_representation")
-    )
+    supporting_rep_identifier = get_obj_uri(get_object_attribute_no_verif(energyml_array, "supporting_representation"))
     # logging.debug(f"energyml_array : {energyml_array}\n\t{supporting_rep_identifier}")
-    supporting_rep = workspace.get_object_by_identifier(supporting_rep_identifier)
+    supporting_rep = workspace.get_object(supporting_rep_identifier) if workspace is not None else None

     # TODO: search for a pattern like \.*patch\.*.[d]+ to find the patch number in path_in_root, then read the patch
     # logging.debug(f"path_in_root {path_in_root}")
@@ -616,15 +675,15 @@ def read_grid2d_patch(
     patch: Any,
     grid2d: Optional[Any] = None,
     path_in_root: Optional[str] = None,
-    workspace: Optional[EnergymlWorkspace] = None,
-    sub_indices: List[int] = None,
-) -> List:
+    workspace: Optional[EnergymlStorageInterface] = None,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
+) -> Union[List, np.ndarray]:

     points_path, points_obj = search_attribute_matching_name_with_path(patch, "Geometry.Points")[0]
     return read_array(
         energyml_array=points_obj,
         root_obj=grid2d,
-        path_in_root=path_in_root + "." + points_path,
+        path_in_root=path_in_root + "." + points_path if path_in_root else points_path,
         workspace=workspace,
         sub_indices=sub_indices,
     )
@@ -634,8 +693,8 @@ def read_point3d_lattice_array(
     energyml_array: Any,
     root_obj: Optional[Any] = None,
     path_in_root: Optional[str] = None,
-    workspace: Optional[EnergymlWorkspace] = None,
-    sub_indices: List[int] = None,
+    workspace: Optional[EnergymlStorageInterface] = None,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
 ) -> List:
     """
     Read a Point3DLatticeArray.
@@ -661,14 +720,14 @@ def read_point3d_lattice_array( obj=energyml_array, name_rgx="slowestAxisCount", root_obj=root_obj, - current_path=path_in_root, + current_path=path_in_root or "", ) crs_fa_count = search_attribute_in_upper_matching_name( obj=energyml_array, name_rgx="fastestAxisCount", root_obj=root_obj, - current_path=path_in_root, + current_path=path_in_root or "", ) crs = None @@ -695,7 +754,11 @@ def read_point3d_lattice_array( slowest_size = len(slowest_table) fastest_size = len(fastest_table) - if len(crs_sa_count) > 0 and len(crs_fa_count) > 0: + logging.debug(f"slowest vector: {slowest_vec}, spacing: {slowest_spacing}, size: {slowest_size}") + logging.debug(f"fastest vector: {fastest_vec}, spacing: {fastest_spacing}, size: {fastest_size}") + logging.debug(f"origin: {origin}, zincreasing_downward: {zincreasing_downward}") + + if crs_sa_count is not None and len(crs_sa_count) > 0 and crs_fa_count is not None and len(crs_fa_count) > 0: if (crs_sa_count[0] == fastest_size and crs_fa_count[0] == slowest_size) or ( crs_sa_count[0] == fastest_size - 1 and crs_fa_count[0] == slowest_size - 1 ): @@ -712,40 +775,74 @@ def read_point3d_lattice_array( slowest_size = crs_sa_count[0] fastest_size = crs_fa_count[0] - for i in range(slowest_size): - for j in range(fastest_size): - previous_value = origin - # to avoid a sum of the parts of the array at each iteration, I take the previous value in the same line - # number i and add the fastest_table[j] value - - if j > 0: - if i > 0: - line_idx = i * fastest_size # numero de ligne - previous_value = result[line_idx + j - 1] - else: - previous_value = result[j - 1] - if zincreasing_downward: - result.append(sum_lists(previous_value, slowest_table[i - 1])) - else: - result.append(sum_lists(previous_value, fastest_table[j - 1])) - else: - if i > 0: - prev_line_idx = (i - 1) * fastest_size # numero de ligne precedent - previous_value = result[prev_line_idx] - if zincreasing_downward: - result.append(sum_lists(previous_value, fastest_table[j - 1])) + # Vectorized approach using NumPy for massive performance improvement + try: + # Convert tables to NumPy arrays + origin_arr = np.array(origin, dtype=float) + slowest_arr = np.array(slowest_table, dtype=float) # shape: (slowest_size, 3) + fastest_arr = np.array(fastest_table, dtype=float) # shape: (fastest_size, 3) + + # Compute cumulative sums + slowest_cumsum = np.cumsum(slowest_arr, axis=0) # cumulative offset along slowest axis + fastest_cumsum = np.cumsum(fastest_arr, axis=0) # cumulative offset along fastest axis + + # Create meshgrid indices + i_indices, j_indices = np.meshgrid(np.arange(slowest_size), np.arange(fastest_size), indexing="ij") + + # Initialize result array + result_arr = np.zeros((slowest_size, fastest_size, 3), dtype=float) + result_arr[:, :, :] = origin_arr # broadcast origin to all positions + + # Add offsets based on zincreasing_downward + if zincreasing_downward: + # Add slowest offsets where i > 0 + result_arr[1:, :, :] += slowest_cumsum[:-1, np.newaxis, :] + # Add fastest offsets where j > 0 + result_arr[:, 1:, :] += fastest_cumsum[np.newaxis, :-1, :] + else: + # Add fastest offsets where j > 0 + result_arr[:, 1:, :] += fastest_cumsum[np.newaxis, :-1, :] + # Add slowest offsets where i > 0 + result_arr[1:, :, :] += slowest_cumsum[:-1, np.newaxis, :] + + # Flatten to list of points + result = result_arr.reshape(-1, 3).tolist() + + except (ValueError, TypeError) as e: + # Fallback to original implementation if NumPy conversion fails + logging.warning(f"NumPy vectorization 
failed ({e}), falling back to iterative approach") + for i in range(slowest_size): + for j in range(fastest_size): + previous_value = origin + + if j > 0: + if i > 0: + line_idx = i * fastest_size + previous_value = result[line_idx + j - 1] else: + previous_value = result[j - 1] + if zincreasing_downward: result.append(sum_lists(previous_value, slowest_table[i - 1])) + else: + result.append(sum_lists(previous_value, fastest_table[j - 1])) else: - result.append(previous_value) + if i > 0: + prev_line_idx = (i - 1) * fastest_size + previous_value = result[prev_line_idx] + if zincreasing_downward: + result.append(sum_lists(previous_value, fastest_table[j - 1])) + else: + result.append(sum_lists(previous_value, slowest_table[i - 1])) + else: + result.append(previous_value) else: raise Exception(f"{type(energyml_array)} read with an offset of length {len(offset)} is not supported") if sub_indices is not None and len(sub_indices) > 0: - res = [] - for idx in sub_indices: - res.append(result[idx]) - result = res + if isinstance(result, np.ndarray): + result = result[sub_indices].tolist() + else: + result = [result[idx] for idx in sub_indices] return result @@ -754,6 +851,6 @@ def read_point3d_lattice_array( # energyml_array: Any, # root_obj: Optional[Any] = None, # path_in_root: Optional[str] = None, -# workspace: Optional[EnergymlWorkspace] = None +# workspace: Optional[EnergymlStorageInterface] = None # ): # logging.debug(energyml_array) diff --git a/energyml-utils/src/energyml/utils/data/mesh.py b/energyml-utils/src/energyml/utils/data/mesh.py index 3ee9409..108da7e 100644 --- a/energyml-utils/src/energyml/utils/data/mesh.py +++ b/energyml-utils/src/energyml/utils/data/mesh.py @@ -16,25 +16,47 @@ from .helper import ( read_array, read_grid2d_patch, - EnergymlWorkspace, get_crs_obj, get_crs_origin_offset, is_z_reversed, ) -from ..epc import Epc, get_obj_identifier, gen_energyml_object_path -from ..epc_stream import EpcStreamReader -from ..exception import ObjectNotFoundNotError -from ..introspection import ( +from energyml.utils.epc import gen_energyml_object_path +from energyml.utils.epc_stream import EpcStreamReader +from energyml.utils.exception import NotSupportedError, ObjectNotFoundNotError +from energyml.utils.introspection import ( + get_obj_uri, search_attribute_matching_name, search_attribute_matching_name_with_path, snake_case, get_object_attribute, + get_object_attribute_rgx, ) +from energyml.utils.storage_interface import EnergymlStorageInterface + + +# Import export functions from new export module for backward compatibility +from .export import export_obj as _export_obj_new _FILE_HEADER: bytes = b"# file exported by energyml-utils python module (Geosiris)\n" Point = list[float] +# ============================ +# TODO : + +# obj_GridConnectionSetRepresentation +# obj_IjkGridRepresentation +# obj_PlaneSetRepresentation +# obj_RepresentationSetRepresentation +# obj_SealedSurfaceFrameworkRepresentation +# obj_SealedVolumeFrameworkRepresentation +# obj_SubRepresentation +# obj_UnstructuredGridRepresentation +# obj_WellboreMarkerFrameRepresentation +# obj_WellboreTrajectoryRepresentation + +# ============================ + class MeshFileFormat(Enum): OFF = "off" @@ -77,12 +99,12 @@ class AbstractMesh: crs_object: Any = field(default=None) - point_list: List[Point] = field( + point_list: Union[List[Point], np.ndarray] = field( default_factory=list, ) identifier: str = field( - default=None, + default="", ) def get_nb_edges(self) -> int: @@ -91,7 +113,7 @@ def get_nb_edges(self) -> int: 
def get_nb_faces(self) -> int: return 0 - def get_indices(self) -> List[List[int]]: + def get_indices(self) -> Union[List[List[int]], np.ndarray]: return [] @@ -102,7 +124,7 @@ class PointSetMesh(AbstractMesh): @dataclass class PolylineSetMesh(AbstractMesh): - line_indices: List[List[int]] = field( + line_indices: Union[List[List[int]], np.ndarray] = field( default_factory=list, ) @@ -112,13 +134,13 @@ def get_nb_edges(self) -> int: def get_nb_faces(self) -> int: return 0 - def get_indices(self) -> List[List[int]]: + def get_indices(self) -> Union[List[List[int]], np.ndarray]: return self.line_indices @dataclass class SurfaceMesh(AbstractMesh): - faces_indices: List[List[int]] = field( + faces_indices: Union[List[List[int]], np.ndarray] = field( default_factory=list, ) @@ -128,7 +150,7 @@ def get_nb_edges(self) -> int: def get_nb_faces(self) -> int: return len(self.faces_indices) - def get_indices(self) -> List[List[int]]: + def get_indices(self) -> Union[List[List[int]], np.ndarray]: return self.faces_indices @@ -145,7 +167,7 @@ def crs_displacement(points: List[Point], crs_obj: Any) -> Tuple[List[Point], Po if crs_point_offset != [0, 0, 0]: for p in points: for xyz in range(len(p)): - p[xyz] = p[xyz] + crs_point_offset[xyz] + p[xyz] = (p[xyz] + crs_point_offset[xyz]) if p[xyz] is not None else None if zincreasing_downward and len(p) >= 3: p[2] = -p[2] @@ -178,9 +200,9 @@ def _mesh_name_mapping(array_type_name: str) -> str: def read_mesh_object( energyml_object: Any, - workspace: Optional[EnergymlWorkspace] = None, + workspace: Optional[EnergymlStorageInterface] = None, use_crs_displacement: bool = False, - sub_indices: List[int] = None, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[AbstractMesh]: """ Read and "meshable" object. If :param:`energyml_object` is not supported, an exception will be raised. 
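# Hedged usage sketch for the signature change above: both Epc and EpcStreamReader
# implement EnergymlStorageInterface, so either can be passed as the workspace.
# The file path and identifier below are placeholders:
#
#   from energyml.utils.epc import Epc
#   epc = Epc.read_file("model.epc")
#   obj = epc.get_object_by_identifier("<uuid>.<version>")
#   meshes = read_mesh_object(obj, workspace=epc, use_crs_displacement=True)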
@@ -190,28 +212,44 @@ def read_mesh_object(
     is used to translate the data with the CRS offsets
     :return:
     """
+    if isinstance(energyml_object, list): return energyml_object
     array_type_name = _mesh_name_mapping(type(energyml_object).__name__)

     reader_func = get_mesh_reader_function(array_type_name)
     if reader_func is not None:
+        # logging.info(f"using function {reader_func} to read type {array_type_name}")
         surfaces: List[AbstractMesh] = reader_func(
             energyml_object=energyml_object, workspace=workspace, sub_indices=sub_indices
         )
-        if use_crs_displacement:
+        if (
+            use_crs_displacement and "wellbore" not in array_type_name.lower()
+        ):  # WellboreFrameRep has already the displacement applied
+            # TODO: the displacement should be done in each reader function to manage specific cases
             for s in surfaces:
+                logging.debug("CRS : %s", s.crs_object.uuid if s.crs_object is not None else "None")
                 crs_displacement(s.point_list, s.crs_object)
         return surfaces
     else:
-        logging.error(f"Type {array_type_name} is not supported: function read_{snake_case(array_type_name)} not found")
-        raise Exception(
-            f"Type {array_type_name} is not supported\n\t{energyml_object}: \n\tfunction read_{snake_case(array_type_name)} not found"
+        # logging.error(f"Type {array_type_name} is not supported: function read_{snake_case(array_type_name)} not found")
+        raise NotSupportedError(
+            f"Type {array_type_name} is not supported\n\tfunction read_{snake_case(array_type_name)} not found"
         )


+def read_ijk_grid_representation(
+    energyml_object: Any,
+    workspace: EnergymlStorageInterface,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
+) -> List[Any]:
+    raise NotSupportedError("IJKGrid representation reading is not supported yet.")
+
+
 def read_point_representation(
-    energyml_object: Any, workspace: EnergymlWorkspace, sub_indices: List[int] = None
+    energyml_object: Any,
+    workspace: EnergymlStorageInterface,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
 ) -> List[PointSetMesh]:

     # pt_geoms = search_attribute_matching_type(point_set, "AbstractGeometry")
@@ -273,7 +311,9 @@ def read_point_representation(


 def read_polyline_representation(
-    energyml_object: Any, workspace: EnergymlWorkspace, sub_indices: List[int] = None
+    energyml_object: Any,
+    workspace: EnergymlStorageInterface,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
 ) -> List[PolylineSetMesh]:

     # pt_geoms = search_attribute_matching_type(point_set, "AbstractGeometry")
@@ -364,7 +404,7 @@ def read_polyline_representation(
         if len(points) > 0:
             meshes.append(
                 PolylineSetMesh(
-                    identifier=f"{get_obj_identifier(energyml_object)}_patch{patch_idx}",
+                    identifier=f"{get_obj_uri(energyml_object)}_patch{patch_idx}",
                     energyml_object=energyml_object,
                     crs_object=crs,
                     point_list=points,
@@ -381,9 +421,9 @@ def gen_surface_grid_geometry(
     energyml_object: Any,
     patch: Any,
     patch_path: Any,
-    workspace: Optional[EnergymlWorkspace] = None,
+    workspace: Optional[EnergymlStorageInterface] = None,
     keep_holes=False,
-    sub_indices: List[int] = None,
+    sub_indices: Optional[Union[List[int], np.ndarray]] = None,
     offset: int = 0,
 ):
     points = read_grid2d_patch(
@@ -392,6 +432,8 @@ def gen_surface_grid_geometry(
         path_in_root=patch_path,
         workspace=workspace,
     )
+    logging.debug(f"Total points read: {len(points)}")
+    logging.debug(f"Sample points: {points[0:5]}")

     fa_count = search_attribute_matching_name(patch, "FastestAxisCount")
     if fa_count is None:
@@ -430,7 +472,7 @@ def gen_surface_grid_geometry(
         sa_count = sa_count + 1
         fa_count = fa_count + 1

-    # logging.debug(f"sa_count
{sa_count} fa_count {fa_count} : {sa_count*fa_count} - {len(points)} ") + logging.debug(f"sa_count {sa_count} fa_count {fa_count} : {sa_count * fa_count} - {len(points)} ") for sa in range(sa_count - 1): for fa in range(fa_count - 1): @@ -478,7 +520,10 @@ def gen_surface_grid_geometry( def read_grid2d_representation( - energyml_object: Any, workspace: Optional[EnergymlWorkspace] = None, keep_holes=False, sub_indices: List[int] = None + energyml_object: Any, + workspace: Optional[EnergymlStorageInterface] = None, + keep_holes=False, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[SurfaceMesh]: # h5_reader = HDF5FileReader() meshes = [] @@ -516,7 +561,7 @@ def read_grid2d_representation( meshes.append( SurfaceMesh( - identifier=f"{get_obj_identifier(energyml_object)}_patch{patch_idx}", + identifier=f"{get_obj_uri(energyml_object)}_patch{patch_idx}", energyml_object=energyml_object, crs_object=crs, point_list=points, @@ -555,7 +600,7 @@ def read_grid2d_representation( ) meshes.append( SurfaceMesh( - identifier=f"{get_obj_identifier(energyml_object)}_patch{patch_idx}", + identifier=f"{get_obj_uri(energyml_object)}_patch{patch_idx}", energyml_object=energyml_object, crs_object=crs, point_list=points, @@ -568,8 +613,8 @@ def read_grid2d_representation( def read_triangulated_set_representation( energyml_object: Any, - workspace: EnergymlWorkspace, - sub_indices: List[int] = None, + workspace: EnergymlStorageInterface, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[SurfaceMesh]: meshes = [] @@ -634,7 +679,7 @@ def read_triangulated_set_representation( total_size = total_size + len(triangles_list) meshes.append( SurfaceMesh( - identifier=f"{get_obj_identifier(energyml_object)}_patch{patch_idx}", + identifier=f"{get_obj_uri(energyml_object)}_patch{patch_idx}", energyml_object=energyml_object, crs_object=crs, point_list=point_list, @@ -647,19 +692,167 @@ def read_triangulated_set_representation( return meshes +def read_wellbore_frame_representation( + energyml_object: Any, + workspace: EnergymlStorageInterface, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, +) -> List[PolylineSetMesh]: + """ + Read a WellboreFrameRepresentation and construct a polyline mesh from the trajectory. 
+ + :param energyml_object: The WellboreFrameRepresentation object + :param workspace: The EnergymlStorageInterface to access related objects + :param sub_indices: Optional list of indices to filter specific nodes + :return: List containing a single PolylineSetMesh representing the wellbore + """ + meshes = [] + + try: + # Read measured depths (NodeMd) + md_array = [] + try: + node_md_path, node_md_obj = search_attribute_matching_name_with_path(energyml_object, "NodeMd")[0] + md_array = read_array( + energyml_array=node_md_obj, + root_obj=energyml_object, + path_in_root=node_md_path, + workspace=workspace, + ) + if not isinstance(md_array, list): + md_array = md_array.tolist() if hasattr(md_array, "tolist") else list(md_array) + except (IndexError, AttributeError) as e: + logging.warning(f"Could not read NodeMd from wellbore frame: {e}") + return meshes + + # Get trajectory reference + trajectory_dor = search_attribute_matching_name(obj=energyml_object, name_rgx="Trajectory")[0] + trajectory_identifier = get_obj_uri(trajectory_dor) + trajectory_obj = workspace.get_object(trajectory_identifier) + + if trajectory_obj is None: + logging.error(f"Trajectory {trajectory_identifier} not found") + return meshes + + # CRS + crs = None + + # Get reference point (wellhead location) - try different attribute paths for different versions + head_x, head_y, head_z = 0.0, 0.0, 0.0 + z_is_up = True # Default assumption + + try: + # Try to get MdDatum (RESQML 2.0.1) or MdInterval.Datum (RESQML 2.2+) + md_datum_dor = None + try: + md_datum_dor = search_attribute_matching_name(obj=trajectory_obj, name_rgx=r"MdDatum")[0] + except IndexError: + try: + md_datum_dor = search_attribute_matching_name(obj=trajectory_obj, name_rgx=r"MdInterval.Datum")[0] + except IndexError: + pass + + if md_datum_dor is not None: + md_datum_identifier = get_obj_uri(md_datum_dor) + md_datum_obj = workspace.get_object(md_datum_identifier) + + if md_datum_obj is not None: + # Try to get coordinates from ReferencePointInACrs + try: + head_x = get_object_attribute_rgx(md_datum_obj, r"HorizontalCoordinates.Coordinate1") or 0.0 + head_y = get_object_attribute_rgx(md_datum_obj, r"HorizontalCoordinates.Coordinate2") or 0.0 + head_z = get_object_attribute_rgx(md_datum_obj, "VerticalCoordinate") or 0.0 + + # Get vertical CRS to determine z direction + try: + vcrs_dor = search_attribute_matching_name(obj=md_datum_obj, name_rgx="VerticalCrs")[0] + vcrs_identifier = get_obj_uri(vcrs_dor) + vcrs_obj = workspace.get_object(vcrs_identifier) + + if vcrs_obj is not None: + z_is_up = not is_z_reversed(vcrs_obj) + except (IndexError, AttributeError): + pass + except AttributeError: + pass + # Get CRS from trajectory geometry if available + try: + geometry_paths = search_attribute_matching_name_with_path(md_datum_obj, r"VerticalCrs") + if len(geometry_paths) > 0: + crs_dor_path, crs_dor = geometry_paths[0] + crs_identifier = get_obj_uri(crs_dor) + crs = workspace.get_object(crs_identifier) + except Exception as e: + logging.debug(f"Could not get CRS from trajectory: {e}") + except Exception as e: + logging.debug(f"Could not get reference point from trajectory: {e}") + + # Build wellbore path points - simple vertical projection from measured depths + # Note: This is a simplified representation. For accurate 3D trajectory, + # you would need to interpolate along the trajectory's control points. 
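+        # Hedged numeric illustration of the projection below (all values are made up):
+        #   head_x, head_y, head_z = 500.0, 600.0, 30.0
+        #   md_array = [0.0, 100.0, 250.0], z_is_up = False
+        #   -> points       == [[500.0, 600.0, 30.0], [500.0, 600.0, -70.0], [500.0, 600.0, -220.0]]
+        #   -> line_indices == [[0, 1], [1, 2]]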
+ points = [] + line_indices = [] + + for i, md in enumerate(md_array): + # Create point at (head_x, head_y, head_z +/- md) + # Apply z direction based on CRS + z_offset = md if z_is_up else -md + points.append([head_x, head_y, head_z + z_offset]) + + # Connect consecutive points + if i > 0: + line_indices.append([i - 1, i]) + + # Apply sub_indices filter if provided + if sub_indices is not None and len(sub_indices) > 0: + filtered_points = [] + filtered_indices = [] + index_map = {} + + for new_idx, old_idx in enumerate(sub_indices): + if 0 <= old_idx < len(points): + filtered_points.append(points[old_idx]) + index_map[old_idx] = new_idx + + for line in line_indices: + if line[0] in index_map and line[1] in index_map: + filtered_indices.append([index_map[line[0]], index_map[line[1]]]) + + points = filtered_points + line_indices = filtered_indices + + if len(points) > 0: + meshes.append( + PolylineSetMesh( + identifier=f"{get_obj_uri(energyml_object)}_wellbore", + energyml_object=energyml_object, + crs_object=crs, + point_list=points, + line_indices=line_indices, + ) + ) + + except Exception as e: + logging.error(f"Failed to read wellbore frame representation: {e}") + import traceback + + traceback.print_exc() + + return meshes + + def read_sub_representation( energyml_object: Any, - workspace: EnergymlWorkspace, - sub_indices: List[int] = None, + workspace: EnergymlStorageInterface, + sub_indices: Optional[Union[List[int], np.ndarray]] = None, ) -> List[AbstractMesh]: supporting_rep_dor = search_attribute_matching_name( obj=energyml_object, name_rgx=r"(SupportingRepresentation|RepresentedObject)" )[0] - supporting_rep_identifier = get_obj_identifier(supporting_rep_dor) - supporting_rep = workspace.get_object_by_identifier(supporting_rep_identifier) + supporting_rep_identifier = get_obj_uri(supporting_rep_dor) + supporting_rep = workspace.get_object(supporting_rep_identifier) total_size = 0 - all_indices = [] + all_indices = None for patch_path, patch_indices in search_attribute_matching_name_with_path( obj=energyml_object, name_rgx="SubRepresentationPatch.\\d+.ElementIndices.\\d+.Indices", @@ -690,7 +883,7 @@ def read_sub_representation( else: total_size = total_size + len(array) - all_indices = all_indices + array + all_indices = all_indices + array if all_indices is not None else array meshes = read_mesh_object( energyml_object=supporting_rep, workspace=workspace, @@ -698,7 +891,7 @@ def read_sub_representation( ) for m in meshes: - m.identifier = f"sub representation {get_obj_identifier(energyml_object)} of {m.identifier}" + m.identifier = f"sub representation {get_obj_uri(energyml_object)} of {m.identifier}" return meshes @@ -1250,31 +1443,17 @@ def export_obj(mesh_list: List[AbstractMesh], out: BytesIO, obj_name: Optional[s """ Export an :class:`AbstractMesh` into obj format. + This function is maintained for backward compatibility and delegates to the + export module. For new code, consider importing from energyml.utils.data.export. + Each AbstractMesh from the list :param:`mesh_list` will be placed into its own group. 
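+    Illustrative use (a hedged sketch: ``meshes`` would come from read_mesh_object and
+    the output file name is a placeholder):
+
+        from io import BytesIO
+
+        buf = BytesIO()
+        export_obj(meshes, buf, obj_name="surface")
+        with open("surface.obj", "wb") as f:
+            f.write(buf.getvalue())
+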
:param mesh_list: :param out: :param obj_name: :return: """ - out.write("# Generated by energyml-utils a Geosiris python module\n\n".encode("utf-8")) - - if obj_name is not None: - out.write(f"o {obj_name}\n\n".encode("utf-8")) - - point_offset = 0 - for m in mesh_list: - out.write(f"g {m.identifier}\n\n".encode("utf-8")) - _export_obj_elt( - off_point_part=out, - off_face_part=out, - points=m.point_list, - indices=m.get_indices(), - point_offset=point_offset, - colors=[], - elt_letter="l" if isinstance(m, PolylineSetMesh) else "f", - ) - point_offset = point_offset + len(m.point_list) - out.write("\n".encode("utf-8")) + # Delegate to the new export module + _export_obj_new(mesh_list, out, obj_name) def _export_obj_elt( diff --git a/energyml-utils/src/energyml/utils/epc.py b/energyml-utils/src/energyml/utils/epc.py index 28e7c1b..e44fe22 100644 --- a/energyml-utils/src/energyml/utils/epc.py +++ b/energyml-utils/src/energyml/utils/epc.py @@ -30,6 +30,7 @@ Keywords1, TargetMode, ) +from energyml.utils.storage_interface import DataArrayMetadata, EnergymlStorageInterface, ResourceMetadata import numpy as np from .uri import Uri, parse_uri from xsdata.formats.dataclass.models.generics import DerivedElement @@ -87,12 +88,11 @@ read_energyml_json_bytes, JSON_VERSION, ) -from .workspace import EnergymlWorkspace from .xml import is_energyml_content_type @dataclass -class Epc(EnergymlWorkspace): +class Epc(EnergymlStorageInterface): """ A class that represent an EPC file content """ @@ -125,6 +125,8 @@ class Epc(EnergymlWorkspace): default_factory=list, ) + force_h5_path: Optional[str] = field(default=None) + """ Additional rels for objects. Key is the object (same than in @energyml_objects) and value is a list of RelationShip. This can be used to link an HDF5 to an ExternalPartReference in resqml 2.0.1 @@ -429,6 +431,10 @@ def get_h5_file_paths(self, obj: Any) -> List[str]: Get all HDF5 file paths referenced in the EPC file (from rels to external resources) :return: list of HDF5 file paths """ + + if self.force_h5_path is not None: + return [self.force_h5_path] + is_uri = (isinstance(obj, str) and parse_uri(obj) is not None) or isinstance(obj, Uri) if is_uri: obj = self.get_object_by_identifier(obj) @@ -452,8 +458,6 @@ def get_h5_file_paths(self, obj: Any) -> List[str]: h5_paths.add(possible_h5_path) return list(h5_paths) - # -- Functions inherited from EnergymlWorkspace - def get_object_as_dor(self, identifier: str, dor_qualified_type) -> Optional[Any]: """ Search an object by its identifier and returns a DOR @@ -487,8 +491,8 @@ def get_object_by_identifier(self, identifier: Union[str, Uri]) -> Optional[Any] return o return None - def get_object(self, uuid: str, object_version: Optional[str]) -> Optional[Any]: - return self.get_object_by_identifier(f"{uuid}.{object_version or ''}") + def get_object(self, identifier: Union[str, Uri]) -> Optional[Any]: + return self.get_object_by_identifier(identifier) def add_object(self, obj: Any) -> bool: """ @@ -634,11 +638,12 @@ def write_array( # Class methods @classmethod - def read_file(cls, epc_file_path: str): + def read_file(cls, epc_file_path: str) -> "Epc": with open(epc_file_path, "rb") as f: epc = cls.read_stream(BytesIO(f.read())) epc.epc_file_path = epc_file_path return epc + raise IOError(f"Failed to open EPC file {epc_file_path}") @classmethod def read_stream(cls, epc_file_io: BytesIO): # returns an Epc instance @@ -770,6 +775,45 @@ def read_stream(cls, epc_file_io: BytesIO): # returns an Epc instance return None + def list_objects(self, 
dataspace: str | None = None, object_type: str | None = None) -> List[ResourceMetadata]: + result = [] + for obj in self.energyml_objects: + if (dataspace is None or get_obj_type(get_obj_usable_class(obj)) == dataspace) and ( + object_type is None or get_qualified_type_from_class(type(obj)) == object_type + ): + res_meta = ResourceMetadata( + uri=str(get_obj_uri(obj)), + uuid=get_obj_uuid(obj), + title=get_object_attribute(obj, "citation.title") or "", + object_type=type(obj).__name__, + version=get_obj_version(obj), + content_type=get_content_type_from_class(type(obj)) or "", + ) + result.append(res_meta) + return result + + def put_object(self, obj: Any, dataspace: str | None = None) -> str | None: + if self.add_object(obj): + return str(get_obj_uri(obj)) + return None + + def delete_object(self, identifier: Union[str, Any]) -> bool: + obj = self.get_object_by_identifier(identifier) + if obj is not None: + self.remove_object(identifier) + return True + return False + + def get_array_metadata( + self, proxy: str | Uri | Any, path_in_external: str | None = None + ) -> DataArrayMetadata | List[DataArrayMetadata] | None: + array = self.read_array(proxy=proxy, path_in_external=path_in_external) + if array is not None: + if isinstance(array, np.ndarray): + return DataArrayMetadata.from_numpy_array(path_in_resource=path_in_external, array=array) + elif isinstance(array, list): + return DataArrayMetadata.from_list(path_in_resource=path_in_external, data=array) + def dumps_epc_content_and_files_lists(self) -> str: """ Dumps the EPC content and files lists for debugging purposes. @@ -782,6 +826,13 @@ def dumps_epc_content_and_files_lists(self) -> str: return "EPC Content:\n" + "\n".join(content_list) + "\n\nRaw Files:\n" + "\n".join(raw_files_list) + def close(self) -> None: + """ + Close the EPC file and release any resources. 
+ :return: + """ + pass + # ______ __ ____ __ _ # / ____/___ ___ _________ ___ ______ ___ / / / __/_ ______ _____/ /_(_)___ ____ _____ diff --git a/energyml-utils/src/energyml/utils/epc_stream.py b/energyml-utils/src/energyml/utils/epc_stream.py index 721f9d6..6c8686a 100644 --- a/energyml-utils/src/energyml/utils/epc_stream.py +++ b/energyml-utils/src/energyml/utils/epc_stream.py @@ -16,17 +16,24 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Any, Iterator, Union, Tuple +from typing import Dict, List, Optional, Any, Iterator, Union, Tuple, TypedDict from weakref import WeakValueDictionary from energyml.opc.opc import Types, Override, CoreProperties, Relationships, Relationship from energyml.utils.data.datasets_io import HDF5FileReader, HDF5FileWriter +from energyml.utils.storage_interface import DataArrayMetadata, EnergymlStorageInterface, ResourceMetadata from energyml.utils.uri import Uri, parse_uri -from energyml.utils.workspace import EnergymlWorkspace +import h5py import numpy as np -from .constants import EPCRelsRelationshipType, OptimizedRegex, EpcExportVersion -from .epc import Epc, gen_energyml_object_path, gen_rels_path, get_epc_content_type_path -from .introspection import ( +from energyml.utils.constants import ( + EPCRelsRelationshipType, + OptimizedRegex, + EpcExportVersion, + content_type_to_qualified_type, +) +from energyml.utils.epc import Epc, gen_energyml_object_path, gen_rels_path, get_epc_content_type_path + +from energyml.utils.introspection import ( get_class_from_content_type, get_obj_content_type, get_obj_identifier, @@ -36,8 +43,29 @@ get_obj_type, get_obj_usable_class, ) -from .serialization import read_energyml_xml_bytes, serialize_xml +from energyml.utils.serialization import read_energyml_xml_bytes, serialize_xml from .xml import is_energyml_content_type +from enum import Enum + + +class RelsUpdateMode(Enum): + """ + Relationship update modes for EPC file management. + + UPDATE_AT_MODIFICATION: Maintain relationships in real-time as objects are added/removed/modified. + This provides the best consistency but may be slower for bulk operations. + + UPDATE_ON_CLOSE: Rebuild all relationships when closing the EPC file. + This is more efficient for bulk operations but relationships are only + consistent after closing. + + MANUAL: No automatic relationship updates. User must manually call rebuild_all_rels(). + This provides maximum control and performance for advanced use cases. 
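+    Illustrative selection (a hedged sketch: the constructor keyword name is an assumption;
+    only RelsUpdateMode itself and rebuild_all_rels() are taken from this module):
+
+        # Bulk import: skip per-object rels maintenance and rebuild once at the end.
+        reader = EpcStreamReader("big_model.epc", rels_update_mode=RelsUpdateMode.MANUAL)
+        ...  # many object additions/updates
+        reader.rebuild_all_rels()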
+ """ + + UPDATE_AT_MODIFICATION = "update_at_modification" + UPDATE_ON_CLOSE = "update_on_close" + MANUAL = "manual" @dataclass(frozen=True) @@ -48,8 +76,8 @@ class EpcObjectMetadata: object_type: str content_type: str file_path: str - version: Optional[str] = None identifier: Optional[str] = None + version: Optional[str] = None def __post_init__(self): if self.identifier is None: @@ -79,126 +107,222 @@ def memory_efficiency(self) -> float: return (1 - (self.loaded_objects / self.total_objects)) * 100 if self.total_objects > 0 else 100.0 -class EpcStreamReader(EnergymlWorkspace): +# =========================================================================================== +# PARALLEL PROCESSING WORKER FUNCTIONS +# =========================================================================================== + +# Configuration constants for parallel processing +_MIN_OBJECTS_PER_WORKER = 10 # Minimum objects to justify spawning a worker +_WORKER_POOL_SIZE_RATIO = 10 # Number of objects per worker process + + +class _WorkerResult(TypedDict): + """Type definition for parallel worker function return value.""" + + identifier: str + object_type: str + source_rels: List[Dict[str, str]] + dor_targets: List[Tuple[str, str]] + + +def _process_object_for_rels_worker(args: Tuple[str, str, Dict[str, EpcObjectMetadata]]) -> Optional[_WorkerResult]: """ - Memory-efficient EPC file reader with lazy loading and smart caching. + Worker function for parallel relationship processing (runs in separate process). - This class provides the same interface as the standard Epc class but loads - objects on-demand rather than keeping everything in memory. Perfect for - handling very large EPC files with thousands of objects. + This function is executed in a separate process to compute SOURCE relationships + for a single object. It bypasses Python's GIL for CPU-intensive XML parsing. - Features: - - Lazy loading: Objects loaded only when accessed - - Smart caching: LRU cache with configurable size - - Memory monitoring: Track memory usage and cache efficiency - - Streaming validation: Validate objects without full loading - - Batch operations: Efficient bulk operations - - Context management: Automatic resource cleanup + Performance characteristics: + - Each worker process opens its own ZIP file handle + - XML parsing happens independently on separate CPU cores + - Results are serialized back to the main process via pickle - Performance optimizations: - - Pre-compiled regex patterns for 15-75% faster parsing - - Weak references to prevent memory leaks - - Compressed metadata storage - - Efficient ZIP file handling + Args: + args: Tuple containing: + - identifier: Object UUID/identifier to process + - epc_file_path: Absolute path to the EPC file + - metadata_dict: Dictionary of all object metadata (for validation) + + Returns: + Dictionary conforming to _WorkerResult TypedDict, or None if processing fails. 
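+    Example shape of a successful result (all values are illustrative placeholders):
+
+        {
+            "identifier": "<uuid>.<version>",
+            "object_type": "TriangulatedSetRepresentation",
+            "source_rels": [
+                {"target": "<path of referenced object in the EPC>",
+                 "type_value": EPCRelsRelationshipType.SOURCE_OBJECT.get_type(),
+                 "id": "_<identifier>_<TargetType>_<target identifier>"},
+            ],
+            "dor_targets": [("<target identifier>", "<TargetType>")],
+        }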
""" + identifier, epc_file_path, metadata_dict = args + + try: + # Open ZIP file in this worker process + import zipfile + from energyml.utils.serialization import read_energyml_xml_bytes + from energyml.utils.introspection import ( + get_direct_dor_list, + get_obj_identifier, + get_obj_type, + get_obj_usable_class, + ) + from energyml.utils.constants import EPCRelsRelationshipType + from energyml.utils.introspection import get_class_from_content_type - def __init__( - self, - epc_file_path: Union[str, Path], - cache_size: int = 100, - validate_on_load: bool = True, - preload_metadata: bool = True, - export_version: EpcExportVersion = EpcExportVersion.CLASSIC, - force_h5_path: Optional[str] = None, - ): - """ - Initialize the EPC stream reader. + metadata = metadata_dict.get(identifier) + if not metadata: + return None - Args: - epc_file_path: Path to the EPC file - cache_size: Maximum number of objects to keep in memory cache - validate_on_load: Whether to validate objects when loading - preload_metadata: Whether to preload all object metadata - export_version: EPC packaging version (CLASSIC or EXPANDED) - force_h5_path: Optional forced HDF5 file path for external resources. If set, all arrays will be read/written from/to this path. - """ - self.epc_file_path = Path(epc_file_path) - self.cache_size = cache_size - self.validate_on_load = validate_on_load - self.force_h5_path = force_h5_path + # Load object from ZIP + with zipfile.ZipFile(epc_file_path, "r") as zf: + obj_data = zf.read(metadata.file_path) + obj_class = get_class_from_content_type(metadata.content_type) + obj = read_energyml_xml_bytes(obj_data, obj_class) - is_new_file = False + # Extract object type (cached to avoid reloading in Phase 3) + obj_type = get_obj_type(get_obj_usable_class(obj)) - # Validate file exists and is readable - if not self.epc_file_path.exists(): - logging.info(f"EPC file not found: {epc_file_path}. Creating a new empty EPC file.") - self._create_empty_epc() - is_new_file = True - # raise FileNotFoundError(f"EPC file not found: {epc_file_path}") + # Get all Data Object References (DORs) from this object + data_object_references = get_direct_dor_list(obj) - if not zipfile.is_zipfile(self.epc_file_path): - raise ValueError(f"File is not a valid ZIP/EPC file: {epc_file_path}") + # Build SOURCE relationships and track referenced objects + source_rels = [] + dor_targets = [] # Track (target_id, target_type) for reverse references - # Check if the ZIP file has the required EPC structure - if not is_new_file: + for dor in data_object_references: try: - with zipfile.ZipFile(self.epc_file_path, "r") as zf: - content_types_path = get_epc_content_type_path() - if content_types_path not in zf.namelist(): - logging.info(f"EPC file is missing required structure. 
Initializing empty EPC file.") - self._create_empty_epc() - is_new_file = True + target_identifier = get_obj_identifier(dor) + if target_identifier not in metadata_dict: + continue + + target_metadata = metadata_dict[target_identifier] + + # Extract target type (needed for relationship ID) + target_type = get_obj_type(get_obj_usable_class(dor)) + dor_targets.append((target_identifier, target_type)) + + # Serialize relationship as dict (Relationship objects aren't picklable) + rel_dict = { + "target": target_metadata.file_path, + "type_value": EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + "id": f"_{identifier}_{target_type}_{target_identifier}", + } + source_rels.append(rel_dict) + except Exception as e: - logging.warning(f"Failed to check EPC structure: {e}. Reinitializing.") + # Don't fail entire object processing for one bad DOR + logging.debug(f"Skipping invalid DOR in {identifier}: {e}") - # Object metadata storage - self._metadata: Dict[str, EpcObjectMetadata] = {} # identifier -> metadata - self._uuid_index: Dict[str, List[str]] = {} # uuid -> list of identifiers - self._type_index: Dict[str, List[str]] = {} # object_type -> list of identifiers + return { + "identifier": identifier, + "object_type": obj_type, + "source_rels": source_rels, + "dor_targets": dor_targets, + } - # Caching system using weak references - self._object_cache: WeakValueDictionary = WeakValueDictionary() - self._access_order: List[str] = [] # LRU tracking + except Exception as e: + logging.warning(f"Worker failed to process {identifier}: {e}") + return None - # Core properties and stats - self._core_props: Optional[CoreProperties] = None - self.stats = EpcStreamingStats() - # File handle management - self._zip_file: Optional[zipfile.ZipFile] = None +# =========================================================================================== +# HELPER CLASSES FOR REFACTORED ARCHITECTURE +# =========================================================================================== - # EPC export version detection - self.export_version: EpcExportVersion = export_version or EpcExportVersion.CLASSIC # Default - # Additional rels management - self.additional_rels: Dict[str, List[Relationship]] = {} +class _ZipFileAccessor: + """ + Internal helper class for managing ZIP file access with proper resource management. - # Initialize by loading metadata - if not is_new_file and preload_metadata: - self._load_metadata() - # Detect EPC version after loading metadata - self.export_version = self._detect_epc_version() + This class handles: + - Persistent ZIP connections when keep_open=True + - On-demand connections when keep_open=False + - Proper cleanup and resource management + - Connection pooling for better performance + """ - def _create_empty_epc(self) -> None: - """Create an empty EPC file structure.""" - # Ensure directory exists - self.epc_file_path.parent.mkdir(parents=True, exist_ok=True) + def __init__(self, epc_file_path: Path, keep_open: bool = False): + """ + Initialize the ZIP file accessor. 
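+        Illustrative use (a hedged sketch: the path is a placeholder):
+
+            accessor = _ZipFileAccessor(Path("big_model.epc"), keep_open=True)
+            accessor.open_persistent_connection()
+            with accessor.get_zip_file() as zf:
+                names = zf.namelist()
+            accessor.close()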
- with zipfile.ZipFile(self.epc_file_path, "w") as zf: - # Create [Content_Types].xml - content_types = Types() - content_types_xml = serialize_xml(content_types) - zf.writestr(get_epc_content_type_path(), content_types_xml) + Args: + epc_file_path: Path to the EPC file + keep_open: If True, maintains a persistent connection + """ + self.epc_file_path = epc_file_path + self.keep_open = keep_open + self._persistent_zip: Optional[zipfile.ZipFile] = None - # Create _rels/.rels - rels = Relationships() - rels_xml = serialize_xml(rels) - zf.writestr("_rels/.rels", rels_xml) + def open_persistent_connection(self) -> None: + """Open a persistent ZIP connection if keep_open is enabled.""" + if self.keep_open and self._persistent_zip is None: + self._persistent_zip = zipfile.ZipFile(self.epc_file_path, "r") - def _load_metadata(self) -> None: + @contextmanager + def get_zip_file(self) -> Iterator[zipfile.ZipFile]: + """ + Context manager for ZIP file access with proper resource management. + + If keep_open is True, uses the persistent connection. Otherwise opens a new one. + """ + if self.keep_open and self._persistent_zip is not None: + # Use persistent connection, don't close it + yield self._persistent_zip + else: + # Open and close per request + zf = None + try: + zf = zipfile.ZipFile(self.epc_file_path, "r") + yield zf + finally: + if zf is not None: + zf.close() + + def reopen_persistent_zip(self) -> None: + """Reopen persistent ZIP file after modifications to reflect changes.""" + if self.keep_open and self._persistent_zip is not None: + try: + self._persistent_zip.close() + except Exception: + pass + self._persistent_zip = zipfile.ZipFile(self.epc_file_path, "r") + + def close(self) -> None: + """Close the persistent ZIP file if it's open.""" + if self._persistent_zip is not None: + try: + self._persistent_zip.close() + except Exception as e: + logging.debug(f"Error closing persistent ZIP file: {e}") + finally: + self._persistent_zip = None + + +class _MetadataManager: + """ + Internal helper class for managing object metadata, indexing, and queries. + + This class handles: + - Loading metadata from [Content_Types].xml + - Maintaining UUID and type indexes + - Fast metadata queries without loading objects + - Version detection + """ + + def __init__(self, zip_accessor: _ZipFileAccessor, stats: EpcStreamingStats): + """ + Initialize the metadata manager. 
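+        Illustrative use (a hedged sketch: ``zip_accessor`` is an existing _ZipFileAccessor,
+        and the type name passed to get_by_type is only an example):
+
+            stats = EpcStreamingStats()
+            manager = _MetadataManager(zip_accessor, stats)
+            manager.load_metadata()
+            ids = manager.get_by_type("TriangulatedSetRepresentation")
+            meta = manager.get_metadata(ids[0]) if ids else None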
+ + Args: + zip_accessor: ZIP file accessor for reading from EPC + stats: Statistics tracker + """ + self.zip_accessor = zip_accessor + self.stats = stats + + # Object metadata storage + self._metadata: Dict[str, EpcObjectMetadata] = {} # identifier -> metadata + self._uuid_index: Dict[str, List[str]] = {} # uuid -> list of identifiers + self._type_index: Dict[str, List[str]] = {} # object_type -> list of identifiers + self._core_props: Optional[CoreProperties] = None + self._core_props_path: Optional[str] = None + + def load_metadata(self) -> None: """Load object metadata from [Content_Types].xml without loading actual objects.""" try: - with self._get_zip_file() as zf: + with self.zip_accessor.get_zip_file() as zf: # Read content types content_types = self._read_content_types(zf) @@ -216,17 +340,6 @@ def _load_metadata(self) -> None: logging.error(f"Failed to load metadata from EPC file: {e}") raise - @contextmanager - def _get_zip_file(self) -> Iterator[zipfile.ZipFile]: - """Context manager for ZIP file access with proper resource management.""" - zf = None - try: - zf = zipfile.ZipFile(self.epc_file_path, "r") - yield zf - finally: - if zf is not None: - zf.close() - def _read_content_types(self, zf: zipfile.ZipFile) -> Types: """Read and parse [Content_Types].xml file.""" content_types_path = get_epc_content_type_path() @@ -282,11 +395,7 @@ def _process_energyml_object_metadata(self, zf: zipfile.ZipFile, override: Overr def _extract_object_info_fast( self, zf: zipfile.ZipFile, file_path: str, content_type: str ) -> Tuple[Optional[str], Optional[str], str]: - """ - Fast extraction of UUID and version from XML without full parsing. - - Uses optimized regex patterns for performance. - """ + """Fast extraction of UUID and version from XML without full parsing.""" try: # Read only the beginning of the file for UUID extraction with zf.open(file_path) as f: @@ -317,75 +426,815 @@ def _extract_object_info_fast( version = str(version) break - # Extract object type from content type - obj_type = self._extract_object_type_from_content_type(content_type) + # Extract object type from content type + obj_type = self._extract_object_type_from_content_type(content_type) + + return uuid, version, obj_type + + except Exception as e: + logging.debug(f"Fast extraction failed for {file_path}: {e}") + return None, None, "Unknown" + + def _extract_object_type_from_content_type(self, content_type: str) -> str: + """Extract object type from content type string.""" + try: + match = OptimizedRegex.CONTENT_TYPE.search(content_type) + if match: + return match.group("type") + except (AttributeError, KeyError): + pass + return "Unknown" + + def _is_core_properties(self, content_type: str) -> bool: + """Check if content type is CoreProperties.""" + return content_type == "application/vnd.openxmlformats-package.core-properties+xml" + + def _process_core_properties_metadata(self, override: Override) -> None: + """Process core properties metadata.""" + if override.part_name: + self._core_props_path = override.part_name.lstrip("/") + + def get_metadata(self, identifier: str) -> Optional[EpcObjectMetadata]: + """Get metadata for an object by identifier.""" + return self._metadata.get(identifier) + + def get_by_uuid(self, uuid: str) -> List[str]: + """Get all identifiers for objects with the given UUID.""" + return self._uuid_index.get(uuid, []) + + def get_by_type(self, object_type: str) -> List[str]: + """Get all identifiers for objects of the given type.""" + return self._type_index.get(object_type, []) + + def 
list_metadata(self, object_type: Optional[str] = None) -> List[EpcObjectMetadata]: + """List metadata for all objects, optionally filtered by type.""" + if object_type is None: + return list(self._metadata.values()) + return [self._metadata[identifier] for identifier in self._type_index.get(object_type, [])] + + def add_metadata(self, metadata: EpcObjectMetadata) -> None: + """Add metadata for a new object.""" + identifier = metadata.identifier + if identifier: + self._metadata[identifier] = metadata + + # Update UUID index + if metadata.uuid not in self._uuid_index: + self._uuid_index[metadata.uuid] = [] + self._uuid_index[metadata.uuid].append(identifier) + + # Update type index + if metadata.object_type not in self._type_index: + self._type_index[metadata.object_type] = [] + self._type_index[metadata.object_type].append(identifier) + + self.stats.total_objects += 1 + + def remove_metadata(self, identifier: str) -> Optional[EpcObjectMetadata]: + """Remove metadata for an object. Returns the removed metadata.""" + metadata = self._metadata.pop(identifier, None) + if metadata: + # Update UUID index + if metadata.uuid in self._uuid_index: + self._uuid_index[metadata.uuid].remove(identifier) + if not self._uuid_index[metadata.uuid]: + del self._uuid_index[metadata.uuid] + + # Update type index + if metadata.object_type in self._type_index: + self._type_index[metadata.object_type].remove(identifier) + if not self._type_index[metadata.object_type]: + del self._type_index[metadata.object_type] + + self.stats.total_objects -= 1 + + return metadata + + def contains(self, identifier: str) -> bool: + """Check if an object with the given identifier exists.""" + return identifier in self._metadata + + def __len__(self) -> int: + """Return total number of objects.""" + return len(self._metadata) + + def __iter__(self) -> Iterator[str]: + """Iterate over object identifiers.""" + return iter(self._metadata.keys()) + + def gen_rels_path_from_metadata(self, metadata: EpcObjectMetadata) -> str: + """Generate rels path from object metadata without loading the object.""" + obj_path = metadata.file_path + # Extract folder and filename from the object path + if "/" in obj_path: + obj_folder = obj_path[: obj_path.rindex("/") + 1] + obj_file_name = obj_path[obj_path.rindex("/") + 1 :] + else: + obj_folder = "" + obj_file_name = obj_path + + return f"{obj_folder}_rels/{obj_file_name}.rels" + + def gen_rels_path_from_identifier(self, identifier: str) -> Optional[str]: + """Generate rels path from object identifier without loading the object.""" + metadata = self._metadata.get(identifier) + if metadata is None: + return None + return self.gen_rels_path_from_metadata(metadata) + + def get_core_properties(self) -> Optional[CoreProperties]: + """Get core properties (loaded lazily).""" + if self._core_props is None and self._core_props_path: + try: + with self.zip_accessor.get_zip_file() as zf: + core_data = zf.read(self._core_props_path) + self.stats.bytes_read += len(core_data) + self._core_props = read_energyml_xml_bytes(core_data, CoreProperties) + except Exception as e: + logging.error(f"Failed to load core properties: {e}") + + return self._core_props + + def detect_epc_version(self) -> EpcExportVersion: + """Detect EPC packaging version based on file structure.""" + try: + with self.zip_accessor.get_zip_file() as zf: + file_list = zf.namelist() + + # Look for patterns that indicate EXPANDED version + for file_path in file_list: + # Skip metadata files + if ( + file_path.startswith("[Content_Types]") + or 
file_path.startswith("_rels/") + or file_path.endswith(".rels") + ): + continue + + # Check for namespace_ prefix pattern + if file_path.startswith("namespace_"): + path_parts = file_path.split("/") + if len(path_parts) >= 2: + logging.info(f"Detected EXPANDED EPC version based on path: {file_path}") + return EpcExportVersion.EXPANDED + + # If no EXPANDED patterns found, assume CLASSIC + logging.info("Detected CLASSIC EPC version") + return EpcExportVersion.CLASSIC + + except Exception as e: + logging.warning(f"Failed to detect EPC version, defaulting to CLASSIC: {e}") + return EpcExportVersion.CLASSIC + + def update_content_types_xml( + self, source_zip: zipfile.ZipFile, metadata: EpcObjectMetadata, add: bool = True + ) -> str: + """Update [Content_Types].xml to add or remove object entry. + + Args: + source_zip: Open ZIP file to read from + metadata: Object metadata + add: If True, add entry; if False, remove entry + + Returns: + Updated [Content_Types].xml as string + """ + # Read existing content types + content_types = self._read_content_types(source_zip) + + if add: + # Add new override entry + new_override = Override() + new_override.part_name = f"/{metadata.file_path}" + new_override.content_type = metadata.content_type + content_types.override.append(new_override) + else: + # Remove override entry + content_types.override = [ + override for override in content_types.override if override.part_name != f"/{metadata.file_path}" + ] + + # Serialize back to XML + return serialize_xml(content_types) + + +class _RelationshipManager: + """ + Internal helper class for managing relationships between objects. + + This class handles: + - Reading relationships from .rels files + - Writing relationship updates + - Supporting 3 update modes (UPDATE_AT_MODIFICATION, UPDATE_ON_CLOSE, MANUAL) + - Preserving EXTERNAL_RESOURCE relationships + - Rebuilding all relationships + """ + + def __init__( + self, + zip_accessor: _ZipFileAccessor, + metadata_manager: _MetadataManager, + stats: EpcStreamingStats, + export_version: EpcExportVersion, + rels_update_mode: RelsUpdateMode, + ): + """ + Initialize the relationship manager. + + Args: + zip_accessor: ZIP file accessor for reading/writing + metadata_manager: Metadata manager for object lookups + stats: Statistics tracker + export_version: EPC export version + rels_update_mode: Relationship update mode + """ + self.zip_accessor = zip_accessor + self.metadata_manager = metadata_manager + self.stats = stats + self.export_version = export_version + self.rels_update_mode = rels_update_mode + + # Additional rels management (for user-added relationships) + self.additional_rels: Dict[str, List[Relationship]] = {} + + def get_obj_rels(self, obj_identifier: str, rels_path: Optional[str] = None) -> List[Relationship]: + """ + Get all relationships for a given object. + Merges relationships from the EPC file with in-memory additional relationships. 
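+        Illustrative call (a hedged sketch: ``identifier`` is a "<uuid>.<version>" string):
+
+            rels = rel_manager.get_obj_rels(identifier)
+            sources = [r for r in rels
+                       if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type()]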
+ """ + rels = [] + + # Read rels from EPC file + if rels_path is None: + rels_path = self.metadata_manager.gen_rels_path_from_identifier(obj_identifier) + + if rels_path is not None: + with self.zip_accessor.get_zip_file() as zf: + try: + rels_data = zf.read(rels_path) + self.stats.bytes_read += len(rels_data) + relationships = read_energyml_xml_bytes(rels_data, Relationships) + rels.extend(relationships.relationship) + except KeyError: + # No rels file found for this object + pass + + # Merge with in-memory additional relationships + if obj_identifier in self.additional_rels: + rels.extend(self.additional_rels[obj_identifier]) + + return rels + + def update_rels_for_new_object(self, obj: Any, obj_identifier: str) -> None: + """Update relationships when a new object is added (UPDATE_AT_MODIFICATION mode).""" + metadata = self.metadata_manager.get_metadata(obj_identifier) + if not metadata: + logging.warning(f"Metadata not found for {obj_identifier}") + return + + # Get all objects this new object references + direct_dors = get_direct_dor_list(obj) + + # Build SOURCE relationships for this object + source_relationships = [] + dest_updates: Dict[str, Relationship] = {} + + for dor in direct_dors: + try: + target_identifier = get_obj_identifier(dor) + if not self.metadata_manager.contains(target_identifier): + continue + + target_metadata = self.metadata_manager.get_metadata(target_identifier) + if not target_metadata: + continue + + # Create SOURCE relationship + source_rel = Relationship( + target=target_metadata.file_path, + type_value=EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + id=f"_{obj_identifier}_{get_obj_type(get_obj_usable_class(dor))}_{target_identifier}", + ) + source_relationships.append(source_rel) + + # Create DESTINATION relationship + dest_rel = Relationship( + target=metadata.file_path, + type_value=EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + id=f"_{target_identifier}_{get_obj_type(get_obj_usable_class(obj))}_{obj_identifier}", + ) + dest_updates[target_identifier] = dest_rel + + except Exception as e: + logging.warning(f"Failed to create relationship for DOR: {e}") + + # Write updates + self.write_rels_updates(obj_identifier, source_relationships, dest_updates) + + def update_rels_for_modified_object(self, obj: Any, obj_identifier: str, old_dors: List[Any]) -> None: + """Update relationships when an object is modified (UPDATE_AT_MODIFICATION mode).""" + metadata = self.metadata_manager.get_metadata(obj_identifier) + if not metadata: + logging.warning(f"Metadata not found for {obj_identifier}") + return + + # Get new DORs + new_dors = get_direct_dor_list(obj) + + # Convert to sets of identifiers for comparison + old_dor_ids = { + get_obj_identifier(dor) for dor in old_dors if self.metadata_manager.contains(get_obj_identifier(dor)) + } + new_dor_ids = { + get_obj_identifier(dor) for dor in new_dors if self.metadata_manager.contains(get_obj_identifier(dor)) + } + + # Find added and removed references + added_dor_ids = new_dor_ids - old_dor_ids + removed_dor_ids = old_dor_ids - new_dor_ids + + # Build new SOURCE relationships + source_relationships = [] + dest_updates: Dict[str, Relationship] = {} + + # Create relationships for all new DORs + for dor in new_dors: + target_identifier = get_obj_identifier(dor) + if not self.metadata_manager.contains(target_identifier): + continue + + target_metadata = self.metadata_manager.get_metadata(target_identifier) + if not target_metadata: + continue + + # SOURCE relationship + source_rel = Relationship( + 
target=target_metadata.file_path, + type_value=EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + id=f"_{obj_identifier}_{get_obj_type(get_obj_usable_class(dor))}_{target_identifier}", + ) + source_relationships.append(source_rel) + + # DESTINATION relationship (for added DORs only) + if target_identifier in added_dor_ids: + dest_rel = Relationship( + target=metadata.file_path, + type_value=EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + id=f"_{target_identifier}_{get_obj_type(get_obj_usable_class(obj))}_{obj_identifier}", + ) + dest_updates[target_identifier] = dest_rel + + # For removed DORs, remove DESTINATION relationships + removals: Dict[str, str] = {} + for removed_id in removed_dor_ids: + removals[removed_id] = f"_{removed_id}_.*_{obj_identifier}" + + # Write updates + self.write_rels_updates(obj_identifier, source_relationships, dest_updates, removals) + + def update_rels_for_removed_object(self, obj_identifier: str, obj: Optional[Any] = None) -> None: + """Update relationships when an object is removed (UPDATE_AT_MODIFICATION mode).""" + if obj is None: + # Object must be provided for removal + logging.warning(f"Cannot update rels for removed object {obj_identifier}: object not provided") + return + + # Get all objects this object references + direct_dors = get_direct_dor_list(obj) + + # Build removal patterns for DESTINATION relationships + removals: Dict[str, str] = {} + for dor in direct_dors: + try: + target_identifier = get_obj_identifier(dor) + if not self.metadata_manager.contains(target_identifier): + continue + + removals[target_identifier] = f"_{target_identifier}_.*_{obj_identifier}" + + except Exception as e: + logging.warning(f"Failed to process DOR for removal: {e}") + + # Write updates + self.write_rels_updates(obj_identifier, [], {}, removals, delete_source_rels=True) + + def write_rels_updates( + self, + source_identifier: str, + source_relationships: List[Relationship], + dest_updates: Dict[str, Relationship], + removals: Optional[Dict[str, str]] = None, + delete_source_rels: bool = False, + ) -> None: + """Write relationship updates to the EPC file efficiently.""" + import re + + removals = removals or {} + rels_updates: Dict[str, str] = {} + files_to_delete: List[str] = [] + + with self.zip_accessor.get_zip_file() as zf: + # 1. Handle source object's rels file + if not delete_source_rels: + source_rels_path = self.metadata_manager.gen_rels_path_from_identifier(source_identifier) + if source_rels_path: + # Read existing rels (excluding SOURCE_OBJECT type) + existing_rels = [] + try: + if source_rels_path in zf.namelist(): + rels_data = zf.read(source_rels_path) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + # Keep only non-SOURCE relationships + existing_rels = [ + r + for r in existing_rels_obj.relationship + if r.type_value != EPCRelsRelationshipType.SOURCE_OBJECT.get_type() + ] + except Exception: + pass + + # Combine with new SOURCE relationships + all_rels = existing_rels + source_relationships + if all_rels: + rels_updates[source_rels_path] = serialize_xml(Relationships(relationship=all_rels)) + elif source_rels_path in zf.namelist() and not all_rels: + files_to_delete.append(source_rels_path) + else: + # Mark source rels file for deletion + source_rels_path = self.metadata_manager.gen_rels_path_from_identifier(source_identifier) + if source_rels_path: + files_to_delete.append(source_rels_path) + + # 2. 
Handle destination updates + for target_identifier, dest_rel in dest_updates.items(): + target_rels_path = self.metadata_manager.gen_rels_path_from_identifier(target_identifier) + if not target_rels_path: + continue + + # Read existing rels + existing_rels = [] + try: + if target_rels_path in zf.namelist(): + rels_data = zf.read(target_rels_path) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + existing_rels = list(existing_rels_obj.relationship) + except Exception: + pass + + # Add new DESTINATION relationship if not already present + rel_exists = any( + r.target == dest_rel.target and r.type_value == dest_rel.type_value for r in existing_rels + ) + + if not rel_exists: + existing_rels.append(dest_rel) + rels_updates[target_rels_path] = serialize_xml(Relationships(relationship=existing_rels)) + + # 3. Handle removals + for target_identifier, pattern in removals.items(): + target_rels_path = self.metadata_manager.gen_rels_path_from_identifier(target_identifier) + if not target_rels_path: + continue + + # Read existing rels + existing_rels = [] + try: + if target_rels_path in zf.namelist(): + rels_data = zf.read(target_rels_path) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + existing_rels = list(existing_rels_obj.relationship) + except Exception: + pass + + # Filter out relationships matching the pattern + regex = re.compile(pattern) + filtered_rels = [r for r in existing_rels if not (r.id and regex.match(r.id))] + + if len(filtered_rels) != len(existing_rels): + if filtered_rels: + rels_updates[target_rels_path] = serialize_xml(Relationships(relationship=filtered_rels)) + else: + files_to_delete.append(target_rels_path) + + # Write updates to EPC file + if rels_updates or files_to_delete: + with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: + temp_path = temp_file.name + + try: + with self.zip_accessor.get_zip_file() as source_zf: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zf: + # Copy all files except those to delete or update + files_to_skip = set(files_to_delete) + for item in source_zf.infolist(): + if item.filename not in files_to_skip and item.filename not in rels_updates: + data = source_zf.read(item.filename) + target_zf.writestr(item, data) + + # Write updated rels files + for rels_path, rels_xml in rels_updates.items(): + target_zf.writestr(rels_path, rels_xml) + + # Replace original + shutil.move(temp_path, self.zip_accessor.epc_file_path) + self.zip_accessor.reopen_persistent_zip() + + except Exception as e: + if os.path.exists(temp_path): + os.unlink(temp_path) + logging.error(f"Failed to write rels updates: {e}") + raise + + def compute_object_rels(self, obj: Any, obj_identifier: str) -> List[Relationship]: + """ + Compute relationships for a given object (SOURCE relationships). + This object references other objects through DORs. 
+ + Args: + obj: The EnergyML object + obj_identifier: The identifier of the object + + Returns: + List of Relationship objects for this object's .rels file + """ + rels = [] + + # Get all DORs (Data Object References) in this object + direct_dors = get_direct_dor_list(obj) + + for dor in direct_dors: + try: + target_identifier = get_obj_identifier(dor) + + # Get target file path from metadata without processing DOR + # The relationship target should be the object's file path, not its rels path + if self.metadata_manager.contains(target_identifier): + target_metadata = self.metadata_manager.get_metadata(target_identifier) + if target_metadata: + target_path = target_metadata.file_path + else: + target_path = gen_energyml_object_path(dor, self.export_version) + else: + # Fall back to generating path from DOR if metadata not found + target_path = gen_energyml_object_path(dor, self.export_version) + + # Create SOURCE relationship (this object -> target object) + rel = Relationship( + target=target_path, + type_value=EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + id=f"_{obj_identifier}_{get_obj_type(get_obj_usable_class(dor))}_{target_identifier}", + ) + rels.append(rel) + except Exception as e: + logging.warning(f"Failed to create relationship for DOR in {obj_identifier}: {e}") + + return rels + + def merge_rels(self, new_rels: List[Relationship], existing_rels: List[Relationship]) -> List[Relationship]: + """Merge new relationships with existing ones, avoiding duplicates and ensuring unique IDs. + + Args: + new_rels: New relationships to add + existing_rels: Existing relationships + + Returns: + Merged list of relationships + """ + merged = list(existing_rels) + + for new_rel in new_rels: + # Check if relationship already exists + rel_exists = any(r.target == new_rel.target and r.type_value == new_rel.type_value for r in merged) + + if not rel_exists: + # Ensure unique ID + cpt = 0 + new_rel_id = new_rel.id + while any(r.id == new_rel_id for r in merged): + new_rel_id = f"{new_rel.id}_{cpt}" + cpt += 1 + if new_rel_id != new_rel.id: + new_rel.id = new_rel_id + + merged.append(new_rel) + + return merged + + +# =========================================================================================== +# MAIN CLASS (REFACTORED TO USE HELPER CLASSES) +# =========================================================================================== + + +class EpcStreamReader(EnergymlStorageInterface): + """ + Memory-efficient EPC file reader with lazy loading and smart caching. + + This class provides the same interface as the standard Epc class but loads + objects on-demand rather than keeping everything in memory. Perfect for + handling very large EPC files with thousands of objects. + + Features: + - Lazy loading: Objects loaded only when accessed + - Smart caching: LRU cache with configurable size + - Memory monitoring: Track memory usage and cache efficiency + - Streaming validation: Validate objects without full loading + - Batch operations: Efficient bulk operations + - Context management: Automatic resource cleanup + - Flexible relationship management: Three modes for updating object relationships + + Relationship Update Modes: + - UPDATE_AT_MODIFICATION: Maintains relationships in real-time as objects are added/removed/modified. + Best for maintaining consistency but may be slower for bulk operations. + - UPDATE_ON_CLOSE: Rebuilds all relationships when closing the EPC file (default). + More efficient for bulk operations but relationships only consistent after closing. 
+ - MANUAL: No automatic relationship updates. User must manually call rebuild_all_rels(). + Maximum control and performance for advanced use cases. + + Performance optimizations: + - Pre-compiled regex patterns for 15-75% faster parsing + - Weak references to prevent memory leaks + - Compressed metadata storage + - Efficient ZIP file handling + """ + + def __init__( + self, + epc_file_path: Union[str, Path], + cache_size: int = 100, + validate_on_load: bool = True, + preload_metadata: bool = True, + export_version: EpcExportVersion = EpcExportVersion.CLASSIC, + force_h5_path: Optional[str] = None, + keep_open: bool = False, + force_title_load: bool = False, + rels_update_mode: RelsUpdateMode = RelsUpdateMode.UPDATE_ON_CLOSE, + enable_parallel_rels: bool = False, + parallel_worker_ratio: int = 10, + ): + """ + Initialize the EPC stream reader. + + Args: + epc_file_path: Path to the EPC file + cache_size: Maximum number of objects to keep in memory cache + validate_on_load: Whether to validate objects when loading + preload_metadata: Whether to preload all object metadata + export_version: EPC packaging version (CLASSIC or EXPANDED) + force_h5_path: Optional forced HDF5 file path for external resources. If set, all arrays will be read/written from/to this path. + keep_open: If True, keeps the ZIP file open for better performance with multiple operations. File is closed only when instance is deleted or close() is called. + force_title_load: If True, forces loading object titles when listing objects (may impact performance) + rels_update_mode: Mode for updating relationships (UPDATE_AT_MODIFICATION, UPDATE_ON_CLOSE, or MANUAL) + enable_parallel_rels: If True, uses parallel processing for rebuild_all_rels() operations (faster for large EPCs) + parallel_worker_ratio: Number of objects per worker process (default: 10). Lower values = more workers. Only used when enable_parallel_rels=True. + """ + # Public attributes + self.epc_file_path = Path(epc_file_path) + self.enable_parallel_rels = enable_parallel_rels + self.parallel_worker_ratio = parallel_worker_ratio + self.cache_size = cache_size + self.validate_on_load = validate_on_load + self.force_h5_path = force_h5_path + self.cache_opened_h5 = None + self.keep_open = keep_open + self.force_title_load = force_title_load + self.rels_update_mode = rels_update_mode + self.export_version: EpcExportVersion = export_version or EpcExportVersion.CLASSIC + self.stats = EpcStreamingStats() + + # Caching system using weak references + self._object_cache: WeakValueDictionary = WeakValueDictionary() + self._access_order: List[str] = [] # LRU tracking + + is_new_file = False + + # Validate file exists and is readable + if not self.epc_file_path.exists(): + logging.info(f"EPC file not found: {epc_file_path}. Creating a new empty EPC file.") + self._create_empty_epc() + is_new_file = True + + if not zipfile.is_zipfile(self.epc_file_path): + raise ValueError(f"File is not a valid ZIP/EPC file: {epc_file_path}") + + # Check if the ZIP file has the required EPC structure + if not is_new_file: + try: + with zipfile.ZipFile(self.epc_file_path, "r") as zf: + content_types_path = get_epc_content_type_path() + if content_types_path not in zf.namelist(): + logging.info("EPC file is missing required structure. Initializing empty EPC file.") + self._create_empty_epc() + is_new_file = True + except Exception as e: + logging.warning(f"Failed to check EPC structure: {e}. 
Reinitializing.") + + # Initialize helper classes (internal architecture) + self._zip_accessor = _ZipFileAccessor(self.epc_file_path, keep_open=keep_open) + self._metadata_mgr = _MetadataManager(self._zip_accessor, self.stats) + self._rels_mgr = _RelationshipManager( + self._zip_accessor, self._metadata_mgr, self.stats, self.export_version, rels_update_mode + ) + + # Initialize by loading metadata + if not is_new_file and preload_metadata: + self._metadata_mgr.load_metadata() + # Detect EPC version after loading metadata + self.export_version = self._metadata_mgr.detect_epc_version() + # Update relationship manager's export version + self._rels_mgr.export_version = self.export_version + + # Open persistent ZIP connection if keep_open is enabled + if keep_open and not is_new_file: + self._zip_accessor.open_persistent_connection() + + # Backward compatibility: expose internal structures as properties + # This allows existing code to access _metadata, _uuid_index, etc. + self._metadata = self._metadata_mgr._metadata + self._uuid_index = self._metadata_mgr._uuid_index + self._type_index = self._metadata_mgr._type_index + self.additional_rels = self._rels_mgr.additional_rels + + def _create_empty_epc(self) -> None: + """Create an empty EPC file structure.""" + # Ensure directory exists + self.epc_file_path.parent.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(self.epc_file_path, "w") as zf: + # Create [Content_Types].xml + content_types = Types() + content_types_xml = serialize_xml(content_types) + zf.writestr(get_epc_content_type_path(), content_types_xml) + + # Create _rels/.rels + rels = Relationships() + rels_xml = serialize_xml(rels) + zf.writestr("_rels/.rels", rels_xml) + + def _load_metadata(self) -> None: + """Load object metadata from [Content_Types].xml without loading actual objects.""" + # Delegate to metadata manager + self._metadata_mgr.load_metadata() + + def _read_content_types(self, zf: zipfile.ZipFile) -> Types: + """Read and parse [Content_Types].xml file.""" + # Delegate to metadata manager + return self._metadata_mgr._read_content_types(zf) - return uuid, version, obj_type + def _process_energyml_object_metadata(self, zf: zipfile.ZipFile, override: Override) -> None: + """Process metadata for an EnergyML object without loading it.""" + # Delegate to metadata manager + self._metadata_mgr._process_energyml_object_metadata(zf, override) - except Exception as e: - logging.debug(f"Fast extraction failed for {file_path}: {e}") - return None, None, "Unknown" + def _extract_object_info_fast( + self, zf: zipfile.ZipFile, file_path: str, content_type: str + ) -> Tuple[Optional[str], Optional[str], str]: + """Fast extraction of UUID and version from XML without full parsing.""" + # Delegate to metadata manager + return self._metadata_mgr._extract_object_info_fast(zf, file_path, content_type) def _extract_object_type_from_content_type(self, content_type: str) -> str: """Extract object type from content type string.""" - try: - match = OptimizedRegex.CONTENT_TYPE.search(content_type) - if match: - return match.group("type") - except (AttributeError, KeyError): - pass - return "Unknown" + # Delegate to metadata manager + return self._metadata_mgr._extract_object_type_from_content_type(content_type) def _is_core_properties(self, content_type: str) -> bool: """Check if content type is CoreProperties.""" - return content_type == "application/vnd.openxmlformats-package.core-properties+xml" + # Delegate to metadata manager + return 
self._metadata_mgr._is_core_properties(content_type) def _process_core_properties_metadata(self, override: Override) -> None: """Process core properties metadata.""" - # Store core properties path for lazy loading - if override.part_name: - self._core_props_path = override.part_name.lstrip("/") + # Delegate to metadata manager + self._metadata_mgr._process_core_properties_metadata(override) def _detect_epc_version(self) -> EpcExportVersion: - """ - Detect EPC packaging version based on file structure. - - CLASSIC version uses simple flat structure: obj_Type_UUID.xml - EXPANDED version uses namespace structure: namespace_pkg/UUID/version_X/Type_UUID.xml - - Returns: - EpcExportVersion: The detected version (CLASSIC or EXPANDED) - """ - try: - with self._get_zip_file() as zf: - file_list = zf.namelist() + """Detect EPC packaging version based on file structure.""" + # Delegate to metadata manager + return self._metadata_mgr.detect_epc_version() - # Look for patterns that indicate EXPANDED version - # EXPANDED uses paths like: namespace_resqml22/UUID/version_X/Type_UUID.xml - for file_path in file_list: - # Skip metadata files - if ( - file_path.startswith("[Content_Types]") - or file_path.startswith("_rels/") - or file_path.endswith(".rels") - ): - continue + def _gen_rels_path_from_metadata(self, metadata: EpcObjectMetadata) -> str: + """Generate rels path from object metadata without loading the object.""" + # Delegate to metadata manager + return self._metadata_mgr.gen_rels_path_from_metadata(metadata) - # Check for namespace_ prefix pattern - if file_path.startswith("namespace_"): - # Further validate it's the EXPANDED structure - path_parts = file_path.split("/") - if len(path_parts) >= 2: # namespace_pkg/filename or namespace_pkg/version_x/filename - logging.info(f"Detected EXPANDED EPC version based on path: {file_path}") - return EpcExportVersion.EXPANDED + def _gen_rels_path_from_identifier(self, identifier: str) -> Optional[str]: + """Generate rels path from object identifier without loading the object.""" + # Delegate to metadata manager + return self._metadata_mgr.gen_rels_path_from_identifier(identifier) - # If no EXPANDED patterns found, assume CLASSIC - logging.info("Detected CLASSIC EPC version") - return EpcExportVersion.CLASSIC + @contextmanager + def _get_zip_file(self) -> Iterator[zipfile.ZipFile]: + """Context manager for ZIP file access with proper resource management. - except Exception as e: - logging.warning(f"Failed to detect EPC version, defaulting to CLASSIC: {e}") - return EpcExportVersion.CLASSIC + If keep_open is True, uses the persistent connection. Otherwise opens a new one. + """ + # Delegate to the ZIP accessor helper class + with self._zip_accessor.get_zip_file() as zf: + yield zf def get_object_by_identifier(self, identifier: Union[str, Uri]) -> Optional[Any]: """ @@ -507,6 +1356,9 @@ def get_object_by_uuid(self, uuid: str) -> List[Any]: return objects + def get_object(self, identifier: Union[str, Uri]) -> Optional[Any]: + return self.get_object_by_identifier(identifier) + def get_objects_by_type(self, object_type: str) -> List[Any]: """Get all objects of the specified type.""" if object_type not in self._type_index: @@ -539,6 +1391,88 @@ def get_statistics(self) -> EpcStreamingStats: """Get current streaming statistics.""" return self.stats + def list_objects( + self, dataspace: Optional[str] = None, object_type: Optional[str] = None + ) -> List[ResourceMetadata]: + """ + List all objects with metadata (EnergymlStorageInterface method). 
+ + Args: + dataspace: Optional dataspace filter (ignored for EPC files) + object_type: Optional type filter (qualified type) + + Returns: + List of ResourceMetadata for all matching objects + """ + + results = [] + metadata_list = self.list_object_metadata(object_type) + + for meta in metadata_list: + try: + # Load object to get title + title = "" + if self.force_title_load and meta.identifier: + obj = self.get_object_by_identifier(meta.identifier) + if obj and hasattr(obj, "citation") and obj.citation: + if hasattr(obj.citation, "title"): + title = obj.citation.title + + # Build URI + qualified_type = content_type_to_qualified_type(meta.content_type) + if meta.version: + uri = f"eml:///{qualified_type}(uuid={meta.uuid},version='{meta.version}')" + else: + uri = f"eml:///{qualified_type}({meta.uuid})" + + resource = ResourceMetadata( + uri=uri, + uuid=meta.uuid, + version=meta.version, + title=title, + object_type=meta.object_type, + content_type=meta.content_type, + ) + + results.append(resource) + except Exception: + continue + + return results + + def get_array_metadata( + self, proxy: Union[str, Uri, Any], path_in_external: Optional[str] = None + ) -> Union[DataArrayMetadata, List[DataArrayMetadata], None]: + """ + Get metadata for data array(s) (EnergymlStorageInterface method). + + Args: + proxy: The object identifier/URI or the object itself + path_in_external: Optional specific path + + Returns: + DataArrayMetadata if path specified, List[DataArrayMetadata] if no path, + or None if not found + """ + from energyml.utils.storage_interface import DataArrayMetadata + + try: + if path_in_external: + array = self.read_array(proxy, path_in_external) + if array is not None: + return DataArrayMetadata( + path_in_resource=path_in_external, + array_type=str(array.dtype), + dimensions=list(array.shape), + ) + else: + # Would need to scan all possible paths - not practical + return [] + except Exception: + pass + + return None + def preload_objects(self, identifiers: List[str]) -> int: """ Preload specific objects into cache. @@ -563,16 +1497,78 @@ def clear_cache(self) -> None: def get_core_properties(self) -> Optional[CoreProperties]: """Get core properties (loaded lazily).""" - if self._core_props is None and hasattr(self, "_core_props_path"): - try: - with self._get_zip_file() as zf: - core_data = zf.read(self._core_props_path) - self.stats.bytes_read += len(core_data) - self._core_props = read_energyml_xml_bytes(core_data, CoreProperties) - except Exception as e: - logging.error(f"Failed to load core properties: {e}") + # Delegate to metadata manager + return self._metadata_mgr.get_core_properties() - return self._core_props + def _gen_rels_path_from_metadata(self, metadata: EpcObjectMetadata) -> str: + """ + Generate rels path from object metadata without loading the object. + + Args: + metadata: Object metadata containing file path information + + Returns: + Path to the rels file for this object + """ + obj_path = metadata.file_path + # Extract folder and filename from the object path + if "/" in obj_path: + obj_folder = obj_path[: obj_path.rindex("/") + 1] + obj_file_name = obj_path[obj_path.rindex("/") + 1 :] + else: + obj_folder = "" + obj_file_name = obj_path + + return f"{obj_folder}_rels/{obj_file_name}.rels" + + def _gen_rels_path_from_identifier(self, identifier: str) -> Optional[str]: + """ + Generate rels path from object identifier without loading the object. 
+ + Args: + identifier: Object identifier (uuid.version) + + Returns: + Path to the rels file, or None if metadata not found + """ + metadata = self._metadata.get(identifier) + if metadata is None: + return None + return self._gen_rels_path_from_metadata(metadata) + + def _update_rels_for_new_object(self, obj: Any, obj_identifier: str) -> None: + """Update relationships when a new object is added (UPDATE_AT_MODIFICATION mode).""" + # Delegate to relationship manager + self._rels_mgr.update_rels_for_new_object(obj, obj_identifier) + + def _update_rels_for_modified_object(self, obj: Any, obj_identifier: str, old_dors: List[Any]) -> None: + """Update relationships when an object is modified (UPDATE_AT_MODIFICATION mode).""" + # Delegate to relationship manager + self._rels_mgr.update_rels_for_modified_object(obj, obj_identifier, old_dors) + + def _update_rels_for_removed_object(self, obj_identifier: str, obj: Optional[Any] = None) -> None: + """Update relationships when an object is removed (UPDATE_AT_MODIFICATION mode).""" + # Delegate to relationship manager + self._rels_mgr.update_rels_for_removed_object(obj_identifier, obj) + + def _write_rels_updates( + self, + source_identifier: str, + source_relationships: List[Relationship], + dest_updates: Dict[str, Relationship], + removals: Optional[Dict[str, str]] = None, + delete_source_rels: bool = False, + ) -> None: + """Write relationship updates to the EPC file efficiently.""" + # Delegate to relationship manager + self._rels_mgr.write_rels_updates( + source_identifier, source_relationships, dest_updates, removals, delete_source_rels + ) + + def _reopen_persistent_zip(self) -> None: + """Reopen persistent ZIP file after modifications to reflect changes.""" + # Delegate to ZIP accessor + self._zip_accessor.reopen_persistent_zip() def to_epc(self, load_all: bool = False) -> Epc: """ @@ -599,33 +1595,86 @@ def to_epc(self, load_all: bool = False) -> Epc: return epc + def set_rels_update_mode(self, mode: RelsUpdateMode) -> None: + """ + Change the relationship update mode. + + Args: + mode: The new RelsUpdateMode to use + + Note: + Changing from MANUAL or UPDATE_ON_CLOSE to UPDATE_AT_MODIFICATION + may require calling rebuild_all_rels() first to ensure consistency. + """ + + def set_rels_update_mode(self, mode: RelsUpdateMode) -> None: + """ + Change the relationship update mode. + + Args: + mode: The new RelsUpdateMode to use + + Note: + Changing from MANUAL or UPDATE_ON_CLOSE to UPDATE_AT_MODIFICATION + may require calling rebuild_all_rels() first to ensure consistency. + """ + if not isinstance(mode, RelsUpdateMode): + raise ValueError(f"mode must be a RelsUpdateMode enum value, got {type(mode)}") + + old_mode = self.rels_update_mode + self.rels_update_mode = mode + # Also update the relationship manager + self._rels_mgr.rels_update_mode = mode + + logging.info(f"Changed relationship update mode from {old_mode.value} to {mode.value}") + + def get_rels_update_mode(self) -> RelsUpdateMode: + """ + Get the current relationship update mode. + + Returns: + The current RelsUpdateMode + """ + return self.rels_update_mode + def get_obj_rels(self, obj: Union[str, Uri, Any]) -> List[Relationship]: """ Get all relationships for a given object. + Merges relationships from the EPC file with in-memory additional relationships. + + Optimized to avoid loading the object when identifier/URI is provided. 
+ :param obj: the object or its identifier/URI :return: list of Relationship objects """ - rels = [] + # Get identifier without loading the object + obj_identifier = None + rels_path = None - # read rels from EPC file if isinstance(obj, (str, Uri)): - obj = self.get_object_by_identifier(obj) - with zipfile.ZipFile(self.epc_file_path, "r") as zf: + # Convert URI to identifier if needed + if isinstance(obj, Uri) or parse_uri(obj) is not None: + uri = parse_uri(obj) if isinstance(obj, str) else obj + assert uri is not None and uri.uuid is not None + obj_identifier = uri.uuid + "." + (uri.version or "") + else: + obj_identifier = obj + + # Generate rels path from metadata without loading the object + rels_path = self._gen_rels_path_from_identifier(obj_identifier) + else: + # We have the actual object + obj_identifier = get_obj_identifier(obj) rels_path = gen_rels_path(obj, self.export_version) - try: - rels_data = zf.read(rels_path) - self.stats.bytes_read += len(rels_data) - relationships = read_energyml_xml_bytes(rels_data, Relationships) - rels.extend(relationships.relationship) - except KeyError: - # No rels file found for this object - pass - return rels + # Delegate to relationship manager + return self._rels_mgr.get_obj_rels(obj_identifier, rels_path) def get_h5_file_paths(self, obj: Union[str, Uri, Any]) -> List[str]: """ - Get all HDF5 file paths referenced in the EPC file (from rels to external resources) + Get all HDF5 file paths referenced in the EPC file (from rels to external resources). + Optimized to avoid loading the object when identifier/URI is provided. + :param obj: the object or its identifier/URI :return: list of HDF5 file paths """ @@ -633,13 +1682,44 @@ def get_h5_file_paths(self, obj: Union[str, Uri, Any]) -> List[str]: return [self.force_h5_path] h5_paths = set() + obj_identifier = None + rels_path = None + + # Get identifier and rels path without loading the object if isinstance(obj, (str, Uri)): - obj = self.get_object_by_identifier(obj) + # Convert URI to identifier if needed + if isinstance(obj, Uri) or parse_uri(obj) is not None: + uri = parse_uri(obj) if isinstance(obj, str) else obj + assert uri is not None and uri.uuid is not None + obj_identifier = uri.uuid + "." 
+ (uri.version or "") + else: + obj_identifier = obj + + # Generate rels path from metadata without loading the object + rels_path = self._gen_rels_path_from_identifier(obj_identifier) + else: + # We have the actual object + obj_identifier = get_obj_identifier(obj) + rels_path = gen_rels_path(obj, self.export_version) - for rels in self.additional_rels.get(get_obj_identifier(obj), []): + # Check in-memory additional rels first + for rels in self.additional_rels.get(obj_identifier, []): if rels.type_value == EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(): h5_paths.add(rels.target) + # Also check rels from the EPC file + if rels_path is not None: + with self._get_zip_file() as zf: + try: + rels_data = zf.read(rels_path) + self.stats.bytes_read += len(rels_data) + relationships = read_energyml_xml_bytes(rels_data, Relationships) + for rel in relationships.relationship: + if rel.type_value == EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(): + h5_paths.add(rel.target) + except KeyError: + pass + if len(h5_paths) == 0: # search if an h5 file has the same name than the epc file epc_folder = os.path.dirname(self.epc_file_path) @@ -659,12 +1739,19 @@ def read_array(self, proxy: Union[str, Uri, Any], path_in_external: str) -> Opti :return: the dataset as a numpy array """ # Resolve proxy to object - if isinstance(proxy, (str, Uri)): - obj = self.get_object_by_identifier(proxy) + + h5_path = [] + if self.force_h5_path is not None: + if self.cache_opened_h5 is None: + self.cache_opened_h5 = h5py.File(self.force_h5_path, "a") + h5_path = [self.cache_opened_h5] else: - obj = proxy + if isinstance(proxy, (str, Uri)): + obj = self.get_object_by_identifier(proxy) + else: + obj = proxy - h5_path = self.get_h5_file_paths(obj) + h5_path = self.get_h5_file_paths(obj) h5_reader = HDF5FileReader() @@ -688,13 +1775,18 @@ def write_array(self, proxy: Union[str, Uri, Any], path_in_external: str, array: return: True if successful """ - # Resolve proxy to object - if isinstance(proxy, (str, Uri)): - obj = self.get_object_by_identifier(proxy) + h5_path = [] + if self.force_h5_path is not None: + if self.cache_opened_h5 is None: + self.cache_opened_h5 = h5py.File(self.force_h5_path, "a") + h5_path = [self.cache_opened_h5] else: - obj = proxy + if isinstance(proxy, (str, Uri)): + obj = self.get_object_by_identifier(proxy) + else: + obj = proxy - h5_path = self.get_h5_file_paths(obj) + h5_path = self.get_h5_file_paths(obj) h5_writer = HDF5FileWriter() @@ -743,7 +1835,7 @@ def validate_all_objects(self, fast_mode: bool = True) -> Dict[str, List[str]]: return results - def get_object_dependencies(self, identifier: str) -> List[str]: + def get_object_dependencies(self, identifier: Union[str, Uri]) -> List[str]: """ Get list of object identifiers that this object depends on. 
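The hunks above and below wire the new RelsUpdateMode options (UPDATE_AT_MODIFICATION, UPDATE_ON_CLOSE, MANUAL) and the keep_open persistent-ZIP handle into EpcStreamReader. As a reading aid, here is a minimal usage sketch of how a caller might combine them; it is not part of the diff itself. The EPC path is a placeholder, the import of RelsUpdateMode alongside EpcStreamReader from energyml.utils.epc_stream is an assumption, and only methods and parameters that appear in this diff are used.

from energyml.utils.epc_stream import EpcStreamReader, RelsUpdateMode  # assumed export location

# Default mode (UPDATE_ON_CLOSE): read-heavy work over a persistent ZIP handle.
# The context-manager exit calls close(), which rebuilds all .rels files once.
with EpcStreamReader("model.epc", cache_size=50, keep_open=True) as reader:  # "model.epc" is a placeholder path
    for res in reader.list_objects():
        _ = reader.get_object_by_uuid(res.uuid)

# MANUAL mode: no automatic relationship maintenance; the caller decides when
# to pay for a full recompute.
reader = EpcStreamReader("model.epc", rels_update_mode=RelsUpdateMode.MANUAL)
try:
    stats = reader.rebuild_all_rels(clean_first=True)   # explicit one-shot rebuild
    print(stats.get("rels_files_created", 0), "rels files written")
    # Switch to per-operation maintenance for subsequent add/remove/update calls.
    reader.set_rels_update_mode(RelsUpdateMode.UPDATE_AT_MODIFICATION)
finally:
    reader.close()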
@@ -772,6 +1864,55 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit with cleanup.""" self.clear_cache() + self.close() + if self.cache_opened_h5 is not None: + try: + self.cache_opened_h5.close() + except Exception: + pass + self.cache_opened_h5 = None + + def __del__(self): + """Destructor to ensure persistent ZIP file is closed.""" + try: + self.close() + if self.cache_opened_h5 is not None: + try: + self.cache_opened_h5.close() + except Exception: + pass + self.cache_opened_h5 = None + except Exception: + pass # Ignore errors during cleanup + + def close(self) -> None: + """Close the persistent ZIP file if it's open, recomputing rels first if mode is UPDATE_ON_CLOSE.""" + # Recompute all relationships before closing if in UPDATE_ON_CLOSE mode + if self.rels_update_mode == RelsUpdateMode.UPDATE_ON_CLOSE: + try: + self.rebuild_all_rels(clean_first=True) + logging.info("Rebuilt all relationships on close (UPDATE_ON_CLOSE mode)") + except Exception as e: + logging.warning(f"Error rebuilding rels on close: {e}") + + # Delegate to ZIP accessor + self._zip_accessor.close() + + def put_object(self, obj: Any, dataspace: Optional[str] = None) -> Optional[str]: + """ + Store an energyml object (EnergymlStorageInterface method). + + Args: + obj: The energyml object to store + dataspace: Optional dataspace name (ignored for EPC files) + + Returns: + The identifier of the stored object (UUID.version or UUID), or None on error + """ + try: + return self.add_object(obj, replace_if_exists=True) + except Exception: + return None def add_object(self, obj: Any, file_path: Optional[str] = None, replace_if_exists: bool = True) -> str: """ @@ -854,6 +1995,10 @@ def add_object(self, obj: Any, file_path: Optional[str] = None, replace_if_exist # Save changes to file self._add_object_to_file(obj, metadata) + # Update relationships if in UPDATE_AT_MODIFICATION mode + if self.rels_update_mode == RelsUpdateMode.UPDATE_AT_MODIFICATION: + self._update_rels_for_new_object(obj, identifier) + # Update stats self.stats.total_objects += 1 @@ -867,6 +2012,18 @@ def add_object(self, obj: Any, file_path: Optional[str] = None, replace_if_exist self._rollback_add_object(identifier) raise RuntimeError(f"Failed to add object to EPC: {e}") + def delete_object(self, identifier: Union[str, Uri]) -> bool: + """ + Delete an object by its identifier (EnergymlStorageInterface method). + + Args: + identifier: Object identifier (UUID or UUID.version) or ETP URI + + Returns: + True if successfully deleted, False otherwise + """ + return self.remove_object(identifier) + def remove_object(self, identifier: Union[str, Uri]) -> bool: """ Remove an object (or all versions of an object) from the EPC file and update caches. @@ -913,13 +2070,27 @@ def remove_object(self, identifier: Union[str, Uri]) -> bool: raise RuntimeError(f"Failed to remove object from EPC: {e}") def _remove_single_object(self, identifier: str) -> bool: - """Remove a single object by its full identifier.""" + """ + Remove a single object by its full identifier. 
+ + Args: + identifier: The full identifier (uuid.version) of the object to remove + Returns: + True if the object was successfully removed, False otherwise + """ try: if identifier not in self._metadata: return False metadata = self._metadata[identifier] + # If in UPDATE_AT_MODIFICATION mode, update rels before removing + obj = None + if self.rels_update_mode == RelsUpdateMode.UPDATE_AT_MODIFICATION: + obj = self.get_object_by_identifier(identifier) + if obj: + self._update_rels_for_removed_object(identifier, obj) + # IMPORTANT: Remove from file FIRST (before clearing cache/metadata) # because _remove_object_from_file needs to load the object to access its DORs self._remove_object_from_file(metadata) @@ -976,11 +2147,114 @@ def update_object(self, obj: Any) -> str: raise ValueError("Object must have a valid identifier and exist in the EPC file") try: - # Remove existing object - self.remove_object(identifier) + # If in UPDATE_AT_MODIFICATION mode, get old DORs and handle update differently + if self.rels_update_mode == RelsUpdateMode.UPDATE_AT_MODIFICATION: + old_obj = self.get_object_by_identifier(identifier) + old_dors = get_direct_dor_list(old_obj) if old_obj else [] + + # Preserve non-SOURCE/DESTINATION relationships (like EXTERNAL_RESOURCE) before removal + preserved_rels = [] + try: + obj_rels = self.get_obj_rels(identifier) + preserved_rels = [ + r + for r in obj_rels + if r.type_value + not in ( + EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + ) + ] + except Exception: + pass + + # Remove existing object (without rels update since we're replacing it) + # Temporarily switch to MANUAL mode to avoid double updates + original_mode = self.rels_update_mode + self.rels_update_mode = RelsUpdateMode.MANUAL + self.remove_object(identifier) + self.rels_update_mode = original_mode + + # Add updated object (without rels update since we'll do custom update) + self.rels_update_mode = RelsUpdateMode.MANUAL + new_identifier = self.add_object(obj) + self.rels_update_mode = original_mode + + # Now do the specialized update that handles both adds and removes + self._update_rels_for_modified_object(obj, new_identifier, old_dors) + + # Restore preserved relationships (like EXTERNAL_RESOURCE) + if preserved_rels: + # These need to be written directly to the rels file + # since _update_rels_for_modified_object already wrote it + rels_path = self._gen_rels_path_from_identifier(new_identifier) + if rels_path: + with self._get_zip_file() as zf: + # Read current rels + current_rels = [] + try: + if rels_path in zf.namelist(): + rels_data = zf.read(rels_path) + rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if rels_obj and rels_obj.relationship: + current_rels = list(rels_obj.relationship) + except Exception: + pass + + # Add preserved rels + all_rels = current_rels + preserved_rels + + # Write back + with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: + temp_path = temp_file.name + + try: + with self._get_zip_file() as source_zf: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zf: + # Copy all files except the rels file we're updating + for item in source_zf.infolist(): + if item.filename != rels_path: + buffer = source_zf.read(item.filename) + target_zf.writestr(item, buffer) + + # Write updated rels file + target_zf.writestr( + rels_path, serialize_xml(Relationships(relationship=all_rels)) + ) + + # Replace original + shutil.move(temp_path, self.epc_file_path) + 
self._reopen_persistent_zip() + + except Exception: + if os.path.exists(temp_path): + os.unlink(temp_path) + raise + + else: + # For other modes (UPDATE_ON_CLOSE, MANUAL), preserve non-SOURCE/DESTINATION relationships + preserved_rels = [] + try: + obj_rels = self.get_obj_rels(identifier) + preserved_rels = [ + r + for r in obj_rels + if r.type_value + not in ( + EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + ) + ] + except Exception: + pass + + # Simple remove + add + self.remove_object(identifier) + new_identifier = self.add_object(obj) - # Add updated object - new_identifier = self.add_object(obj) + # Restore preserved relationships if any + if preserved_rels: + self.add_rels_for_object(new_identifier, preserved_rels, write_immediately=True) logging.info(f"Updated object {identifier} to {new_identifier} in EPC file") return new_identifier @@ -989,399 +2263,383 @@ def update_object(self, obj: Any) -> str: logging.error(f"Failed to update object {identifier}: {e}") raise RuntimeError(f"Failed to update object in EPC: {e}") - def add_rels_for_object(self, identifier: Union[str, Uri, Any], relationships: List[Relationship]) -> None: + def add_rels_for_object( + self, identifier: Union[str, Uri, Any], relationships: List[Relationship], write_immediately: bool = False + ) -> None: """ Add additional relationships for a specific object. + Relationships are stored in memory and can be written immediately or deferred + until write_pending_rels() is called, or when the EPC is closed. + Args: identifier: The identifier of the object, can be str, Uri, or the object itself relationships: List of Relationship objects to add + write_immediately: If True, writes pending rels to disk immediately after adding. + If False (default), rels are kept in memory for batching. """ is_uri = isinstance(identifier, Uri) or (isinstance(identifier, str) and parse_uri(identifier) is not None) - object_instance = None if is_uri: uri = parse_uri(identifier) if isinstance(identifier, str) else identifier assert uri is not None and uri.uuid is not None identifier = uri.uuid + "." 
+ (uri.version or "") - object_instance = self.get_object_by_identifier(identifier) elif not isinstance(identifier, str): identifier = get_obj_identifier(identifier) - object_instance = self.get_object_by_identifier(identifier) - else: - object_instance = identifier assert isinstance(identifier, str) if identifier not in self.additional_rels: self.additional_rels[identifier] = [] - self.additional_rels[identifier].extend(relationships) - if len(self.additional_rels[identifier]) > 0: - # Create temporary file for updated EPC - with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: - temp_path = temp_file.name - # Update the .rels file for this object by updating the rels file in the EPC - with ( - zipfile.ZipFile(self.epc_file_path, "r") as source_zip, - zipfile.ZipFile(temp_path, "a") as target_zip, - ): - # copy all files except the rels file to be updated - for item in source_zip.infolist(): - if item.filename != gen_rels_path(object_instance, self.export_version): - buffer = source_zip.read(item.filename) - target_zip.writestr(item, buffer) - - self._update_existing_rels_files( - Relationships(relationship=relationships), - gen_rels_path(object_instance, self.export_version), - source_zip, - target_zip, - ) - shutil.move(temp_path, self.epc_file_path) - - def _compute_object_rels(self, obj: Any, obj_identifier: str) -> List[Relationship]: - """ - Compute relationships for a given object (SOURCE relationships). - This object references other objects through DORs. - - Args: - obj: The EnergyML object - obj_identifier: The identifier of the object - - Returns: - List of Relationship objects for this object's .rels file - """ - rels = [] - - # Get all DORs (Data Object References) in this object - direct_dors = get_direct_dor_list(obj) - - for dor in direct_dors: - try: - target_identifier = get_obj_identifier(dor) - target_rels_path = gen_rels_path(dor, self.export_version) - - # Create SOURCE relationship (this object -> target object) - rel = Relationship( - target=target_rels_path, - type_value=EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), - id=f"_{obj_identifier}_{get_obj_type(get_obj_usable_class(dor))}_{target_identifier}", - ) - rels.append(rel) - except Exception as e: - logging.warning(f"Failed to create relationship for DOR in {obj_identifier}: {e}") + self.additional_rels[identifier].extend(relationships) + logging.debug(f"Added {len(relationships)} relationships for object {identifier} (in-memory)") - return rels + if write_immediately: + self.write_pending_rels() - def _get_objects_referencing(self, target_identifier: str) -> List[Tuple[str, Any]]: + def write_pending_rels(self) -> int: """ - Find all objects that reference the target object. + Write all pending in-memory relationships to the EPC file efficiently. - Args: - target_identifier: The identifier of the target object + This method reads existing rels, merges them in memory with pending rels, + then rewrites only the affected rels files in a single ZIP update. 
Returns: - List of tuples (identifier, object) of objects that reference the target + Number of rels files updated """ - referencing_objects = [] - - # We need to check all objects in the EPC to find those that reference our target - for identifier in self._metadata: - # Load the object to check its DORs - obj = self.get_object_by_identifier(identifier) - if obj is not None: - # Check if this object references our target - direct_dors = get_direct_dor_list(obj) - for dor in direct_dors: - try: - dor_identifier = get_obj_identifier(dor) - if dor_identifier == target_identifier: - referencing_objects.append((identifier, obj)) - break # Found a reference, no need to check other DORs in this object - except Exception: - continue + if not self.additional_rels: + logging.debug("No pending relationships to write") + return 0 - return referencing_objects + updated_count = 0 - def _update_existing_rels_files( - self, rels: Relationships, rel_path: str, source_zip: zipfile.ZipFile, target_zip: zipfile.ZipFile - ) -> None: - """Merge new relationships with existing .rels, reading from source and writing to target ZIP. - - Args: - rels: New Relationships to add - rel_path: Path to the .rels file - source_zip: ZIP to read existing rels from - target_zip: ZIP to write updated rels to - """ - # print("@ Updating rels file:", rel_path) - existing_relationships = [] - try: - if rel_path in source_zip.namelist(): - rels_data = source_zip.read(rel_path) - existing_rels = read_energyml_xml_bytes(rels_data, Relationships) - if existing_rels and existing_rels.relationship: - existing_relationships = list(existing_rels.relationship) - except Exception as e: - logging.debug(f"Could not read existing rels for {rel_path}: {e}") + # Step 1: Read existing rels and merge with pending rels in memory + merged_rels: Dict[str, Relationships] = {} # rels_path -> merged Relationships - for new_rel in rels.relationship: - rel_exists = any( - r.target == new_rel.target and r.type_value == new_rel.type_value for r in existing_relationships - ) - cpt = 0 - new_rel_id = new_rel.id - while any(r.id == new_rel_id for r in existing_relationships): - new_rel_id = f"{new_rel.id}_{cpt}" - cpt += 1 - if new_rel_id != new_rel.id: - new_rel.id = new_rel_id - if not rel_exists: - existing_relationships.append(new_rel) + with self._get_zip_file() as zf: + for obj_identifier, new_relationships in self.additional_rels.items(): + # Generate rels path from metadata without loading the object + rels_path = self._gen_rels_path_from_identifier(obj_identifier) + if rels_path is None: + logging.warning(f"Could not generate rels path for {obj_identifier}") + continue - if existing_relationships: - updated_rels = Relationships(relationship=existing_relationships) - updated_rels_xml = serialize_xml(updated_rels) - target_zip.writestr(rel_path, updated_rels_xml) + # Read existing rels from ZIP + existing_relationships = [] + try: + if rels_path in zf.namelist(): + rels_data = zf.read(rels_path) + existing_rels = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels and existing_rels.relationship: + existing_relationships = list(existing_rels.relationship) + except Exception as e: + logging.debug(f"Could not read existing rels for {rels_path}: {e}") + + # Merge new relationships, avoiding duplicates + for new_rel in new_relationships: + # Check if relationship already exists + rel_exists = any( + r.target == new_rel.target and r.type_value == new_rel.type_value + for r in existing_relationships + ) - def _update_rels_files( - self, - obj: 
Any, - metadata: EpcObjectMetadata, - source_zip: zipfile.ZipFile, - target_zip: zipfile.ZipFile, - ) -> List[str]: - """ - Update all necessary .rels files when adding/updating an object. + if not rel_exists: + # Ensure unique ID + cpt = 0 + new_rel_id = new_rel.id + while any(r.id == new_rel_id for r in existing_relationships): + new_rel_id = f"{new_rel.id}_{cpt}" + cpt += 1 + if new_rel_id != new_rel.id: + new_rel.id = new_rel_id - This includes: - 1. The object's own .rels file (for objects it references) - 2. The .rels files of objects that now reference this object (DESTINATION relationships) + existing_relationships.append(new_rel) - Args: - obj: The object being added/updated - metadata: Metadata for the object - source_zip: Source ZIP file to read existing rels from - target_zip: Target ZIP file to write updated rels to + # Store merged result + if existing_relationships: + merged_rels[rels_path] = Relationships(relationship=existing_relationships) - returns: - List of updated .rels file paths - """ - obj_identifier = metadata.identifier - updated_rels_paths = [] - if not obj_identifier: - logging.warning("Object identifier is None, skipping rels update") - return updated_rels_paths - - # 1. Create/update the object's own .rels file - obj_rels_path = gen_rels_path(obj, self.export_version) - obj_relationships = self._compute_object_rels(obj, obj_identifier) - - if obj_relationships: - self._update_existing_rels_files( - Relationships(relationship=obj_relationships), obj_rels_path, source_zip, target_zip - ) - updated_rels_paths.append(obj_rels_path) + # Step 2: Write updated rels back to ZIP (create temp, copy all, replace) + if not merged_rels: + return 0 - # 2. Update .rels files of objects referenced by this object - # These objects need DESTINATION relationships pointing to our object - direct_dors = get_direct_dor_list(obj) + with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: + temp_path = temp_file.name - logging.debug(f"Updating rels for object {obj_identifier}, found {len(direct_dors)} direct DORs") + try: + # Copy entire ZIP, replacing only the updated rels files + with self._get_zip_file() as source_zf: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zf: + # Copy all files except the rels we're updating + for item in source_zf.infolist(): + if item.filename not in merged_rels: + buffer = source_zf.read(item.filename) + target_zf.writestr(item, buffer) + + # Write updated rels files + for rels_path, relationships in merged_rels.items(): + rels_xml = serialize_xml(relationships) + target_zf.writestr(rels_path, rels_xml) + updated_count += 1 + + # Replace original with updated ZIP + shutil.move(temp_path, self.epc_file_path) + self._reopen_persistent_zip() - for dor in direct_dors: - try: - target_rels_path = gen_rels_path(dor, self.export_version) - target_identifier = get_obj_identifier(dor) + # Clear pending rels after successful write + self.additional_rels.clear() - # Add DESTINATION relationship from target to our object - dest_rel = Relationship( - target=metadata.file_path, - type_value=EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), - id=f"_{target_identifier}_{get_obj_type(get_obj_usable_class(obj))}_{obj_identifier}", - ) + logging.info(f"Wrote {updated_count} rels files to EPC") + return updated_count - self._update_existing_rels_files( - Relationships(relationship=[dest_rel]), target_rels_path, source_zip, target_zip - ) - updated_rels_paths.append(target_rels_path) + except Exception as e: + if 
os.path.exists(temp_path): + os.unlink(temp_path) + logging.error(f"Failed to write pending rels: {e}") + raise - except Exception as e: - logging.warning(f"Failed to update rels for referenced object: {e}") - return updated_rels_paths + def _compute_object_rels(self, obj: Any, obj_identifier: str) -> List[Relationship]: + """Compute relationships for a given object (SOURCE relationships). - def _remove_rels_files( - self, obj: Any, metadata: EpcObjectMetadata, source_zip: zipfile.ZipFile, target_zip: zipfile.ZipFile - ) -> None: + Delegates to _rels_mgr.compute_object_rels() """ - Remove/update .rels files when removing an object. + return self._rels_mgr.compute_object_rels(obj, obj_identifier) - This includes: - 1. Removing the object's own .rels file - 2. Removing DESTINATION relationships from objects that this object referenced + def _merge_rels(self, new_rels: List[Relationship], existing_rels: List[Relationship]) -> List[Relationship]: + """Merge new relationships with existing ones, avoiding duplicates and ensuring unique IDs. - Args: - obj: The object being removed - metadata: Metadata for the object - source_zip: Source ZIP file to read existing rels from - target_zip: Target ZIP file to write updated rels to + Delegates to _rels_mgr.merge_rels() """ - # obj_identifier = metadata.identifier + return self._rels_mgr.merge_rels(new_rels, existing_rels) - # 1. The object's own .rels file will be automatically excluded by not copying it - # obj_rels_path = gen_rels_path(obj, self.export_version) + def _add_object_to_file(self, obj: Any, metadata: EpcObjectMetadata) -> None: + """Add object to the EPC file efficiently. - # 2. Update .rels files of objects that were referenced by this object - # Remove DESTINATION relationships that pointed to our object - direct_dors = get_direct_dor_list(obj) + Reads existing rels, computes updates in memory, then writes everything + in a single ZIP operation. + """ + xml_content = serialize_xml(obj) + obj_identifier = metadata.identifier + assert obj_identifier is not None, "Object identifier must not be None" - for dor in direct_dors: - try: - target_identifier = get_obj_identifier(dor) + # Step 1: Compute which rels files need to be updated and prepare their content + rels_updates: Dict[str, str] = {} # rels_path -> XML content - # Check if target object exists - if target_identifier not in self._metadata: - continue + with self._get_zip_file() as zf: + # 1a. Object's own .rels file + obj_rels_path = gen_rels_path(obj, self.export_version) + obj_relationships = self._compute_object_rels(obj, obj_identifier) - target_obj = self.get_object_by_identifier(target_identifier) - if target_obj is None: - continue + if obj_relationships: + # Read existing rels + existing_rels = [] + try: + if obj_rels_path in zf.namelist(): + rels_data = zf.read(obj_rels_path) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + existing_rels = list(existing_rels_obj.relationship) + except Exception: + pass - target_rels_path = gen_rels_path(target_obj, self.export_version) + # Merge and serialize + merged_rels = self._merge_rels(obj_relationships, existing_rels) + if merged_rels: + rels_updates[obj_rels_path] = serialize_xml(Relationships(relationship=merged_rels)) - # Read existing rels for the target object - existing_relationships = [] + # 1b. 
Update rels of referenced objects (DESTINATION relationships) + direct_dors = get_direct_dor_list(obj) + for dor in direct_dors: try: - if target_rels_path in source_zip.namelist(): - rels_data = source_zip.read(target_rels_path) - existing_rels = read_energyml_xml_bytes(rels_data, Relationships) - if existing_rels and existing_rels.relationship: - existing_relationships = list(existing_rels.relationship) - except Exception as e: - logging.debug(f"Could not read existing rels for {target_identifier}: {e}") - - # Remove DESTINATION relationship that pointed to our object - updated_relationships = [ - r - for r in existing_relationships - if not ( - r.target == metadata.file_path - and r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type() + target_identifier = get_obj_identifier(dor) + + # Generate rels path from metadata without processing DOR + target_rels_path = self._gen_rels_path_from_identifier(target_identifier) + if target_rels_path is None: + # Fall back to generating from DOR if metadata not found + target_rels_path = gen_rels_path(dor, self.export_version) + + # Create DESTINATION relationship + dest_rel = Relationship( + target=metadata.file_path, + type_value=EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + id=f"_{target_identifier}_{get_obj_type(get_obj_usable_class(obj))}_{obj_identifier}", ) - ] - # Write updated rels file (or skip if no relationships left) - if updated_relationships: - updated_rels = Relationships(relationship=updated_relationships) - updated_rels_xml = serialize_xml(updated_rels) - target_zip.writestr(target_rels_path, updated_rels_xml) + # Read existing rels + existing_rels = [] + try: + if target_rels_path in zf.namelist(): + rels_data = zf.read(target_rels_path) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + existing_rels = list(existing_rels_obj.relationship) + except Exception: + pass - except Exception as e: - logging.warning(f"Failed to update rels for referenced object during removal: {e}") + # Merge and serialize + merged_rels = self._merge_rels([dest_rel], existing_rels) + if merged_rels: + rels_updates[target_rels_path] = serialize_xml(Relationships(relationship=merged_rels)) - def _add_object_to_file(self, obj: Any, metadata: EpcObjectMetadata) -> None: - """Add object to the EPC file by safely rewriting the ZIP archive. + except Exception as e: + logging.warning(f"Failed to prepare rels update for referenced object: {e}") - The method creates a temporary ZIP archive, copies all entries except - the ones to be updated (content types and relevant .rels), then writes - the new object, merges and writes updated .rels files and the - updated [Content_Types].xml before replacing the original file. This - avoids issues with append mode creating overlapped entries. - """ - xml_content = serialize_xml(obj) + # 1c. 
Update [Content_Types].xml + content_types_xml = self._update_content_types_xml(zf, metadata, add=True) - # Create temporary file for updated EPC + # Step 2: Write everything to new ZIP with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: temp_path = temp_file.name try: - with zipfile.ZipFile(self.epc_file_path, "r") as source_zip: - with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zip: + with self._get_zip_file() as source_zf: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zf: + # Write new object + target_zf.writestr(metadata.file_path, xml_content) - # Add new object file - target_zip.writestr(metadata.file_path, xml_content) + # Write updated [Content_Types].xml + target_zf.writestr(get_epc_content_type_path(), content_types_xml) - # Update .rels files by merging with existing ones read from source - updated_rels_paths = self._update_rels_files(obj, metadata, source_zip, target_zip) + # Write updated rels files + for rels_path, rels_xml in rels_updates.items(): + target_zf.writestr(rels_path, rels_xml) - # Copy all existing files except [Content_Types].xml and rels we'll update - for item in source_zip.infolist(): - if item.filename == get_epc_content_type_path() or item.filename in updated_rels_paths: - continue - data = source_zip.read(item.filename) - target_zip.writestr(item, data) + # Copy all other files + files_to_skip = {get_epc_content_type_path(), metadata.file_path} + files_to_skip.update(rels_updates.keys()) - # Update [Content_Types].xml - updated_content_types = self._update_content_types_xml(source_zip, metadata, add=True) - target_zip.writestr(get_epc_content_type_path(), updated_content_types) + for item in source_zf.infolist(): + if item.filename not in files_to_skip: + buffer = source_zf.read(item.filename) + target_zf.writestr(item, buffer) - # Replace original file with updated version + # Replace original shutil.move(temp_path, self.epc_file_path) + self._reopen_persistent_zip() except Exception as e: - # Clean up temp file on error if os.path.exists(temp_path): os.unlink(temp_path) logging.error(f"Failed to add object to EPC file: {e}") raise def _remove_object_from_file(self, metadata: EpcObjectMetadata) -> None: - """Remove object from the EPC file by updating the ZIP archive. + """Remove object from the EPC file efficiently. - Note: This does NOT remove .rels files. Use clean_rels() to remove orphaned relationships. + Reads existing rels, computes updates in memory, then writes everything + in a single ZIP operation. Note: This does NOT remove .rels files. + Use clean_rels() to remove orphaned relationships. 
""" + # Load object first (needed to process its DORs) + if metadata.identifier is None: + logging.error("Cannot remove object with None identifier") + raise ValueError("Object identifier must not be None") - # Create temporary file for updated EPC + obj = self.get_object_by_identifier(metadata.identifier) + if obj is None: + logging.warning(f"Object {metadata.identifier} not found, cannot remove rels") + # Still proceed with removal even if object can't be loaded + + # Step 1: Compute rels updates (remove DESTINATION relationships from referenced objects) + rels_updates: Dict[str, str] = {} # rels_path -> XML content + + if obj is not None: + with self._get_zip_file() as zf: + direct_dors = get_direct_dor_list(obj) + + for dor in direct_dors: + try: + target_identifier = get_obj_identifier(dor) + if target_identifier not in self._metadata: + continue + + # Use metadata to generate rels path without loading the object + target_rels_path = self._gen_rels_path_from_identifier(target_identifier) + if target_rels_path is None: + continue + + # Read existing rels + existing_relationships = [] + try: + if target_rels_path in zf.namelist(): + rels_data = zf.read(target_rels_path) + existing_rels = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels and existing_rels.relationship: + existing_relationships = list(existing_rels.relationship) + except Exception as e: + logging.debug(f"Could not read existing rels for {target_identifier}: {e}") + + # Remove DESTINATION relationship that pointed to our object + updated_relationships = [ + r + for r in existing_relationships + if not ( + r.target == metadata.file_path + and r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type() + ) + ] + + # Only update if relationships remain + if updated_relationships: + rels_updates[target_rels_path] = serialize_xml( + Relationships(relationship=updated_relationships) + ) + + except Exception as e: + logging.warning(f"Failed to update rels for referenced object during removal: {e}") + + # Update [Content_Types].xml + content_types_xml = self._update_content_types_xml(zf, metadata, add=False) + else: + # If we couldn't load the object, still update content types + with self._get_zip_file() as zf: + content_types_xml = self._update_content_types_xml(zf, metadata, add=False) + + # Step 2: Write everything to new ZIP with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: temp_path = temp_file.name try: - # Copy existing EPC to temp file, excluding the object to remove - with zipfile.ZipFile(self.epc_file_path, "r") as source_zip: - with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zip: - # Copy all existing files except the one to remove and [Content_Types].xml - # We keep .rels files as-is (they will be cleaned by clean_rels() if needed) - for item in source_zip.infolist(): - if item.filename not in [metadata.file_path, get_epc_content_type_path()]: - data = source_zip.read(item.filename) - target_zip.writestr(item, data) - - # Update [Content_Types].xml - updated_content_types = self._update_content_types_xml(source_zip, metadata, add=False) - target_zip.writestr(get_epc_content_type_path(), updated_content_types) - - # Replace original file with updated version + with self._get_zip_file() as source_zf: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zf: + # Write updated [Content_Types].xml + target_zf.writestr(get_epc_content_type_path(), content_types_xml) + + # Write updated rels files + for rels_path, rels_xml in 
rels_updates.items(): + target_zf.writestr(rels_path, rels_xml) + + # Copy all files except removed object, its rels, and files we're updating + obj_rels_path = self._gen_rels_path_from_metadata(metadata) + files_to_skip = {get_epc_content_type_path(), metadata.file_path} + if obj_rels_path: + files_to_skip.add(obj_rels_path) + files_to_skip.update(rels_updates.keys()) + + for item in source_zf.infolist(): + if item.filename not in files_to_skip: + buffer = source_zf.read(item.filename) + target_zf.writestr(item, buffer) + + # Replace original shutil.move(temp_path, self.epc_file_path) + self._reopen_persistent_zip() - except Exception: - # Clean up temp file on error + except Exception as e: if os.path.exists(temp_path): os.unlink(temp_path) + logging.error(f"Failed to remove object from EPC file: {e}") raise def _update_content_types_xml( self, source_zip: zipfile.ZipFile, metadata: EpcObjectMetadata, add: bool = True ) -> str: - """Update [Content_Types].xml to add or remove object entry.""" - # Read existing content types - content_types = self._read_content_types(source_zip) - - if add: - # Add new override entry - new_override = Override() - new_override.part_name = f"/{metadata.file_path}" - new_override.content_type = metadata.content_type - content_types.override.append(new_override) - else: - # Remove override entry - content_types.override = [ - override for override in content_types.override if override.part_name != f"/{metadata.file_path}" - ] - - # Serialize back to XML - from .serialization import serialize_xml + """Update [Content_Types].xml to add or remove object entry. - return serialize_xml(content_types) + Delegates to _metadata_mgr.update_content_types_xml() + """ + return self._metadata_mgr.update_content_types_xml(source_zip, metadata, add) def _rollback_add_object(self, identifier: Optional[str]) -> None: """Rollback changes made during failed add_object operation.""" @@ -1441,7 +2699,7 @@ def clean_rels(self) -> Dict[str, int]: temp_path = temp_file.name try: - with zipfile.ZipFile(self.epc_file_path, "r") as source_zip: + with self._get_zip_file() as source_zip: with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zip: # Get all existing object file paths for validation existing_object_files = {metadata.file_path for metadata in self._metadata.values()} @@ -1534,6 +2792,33 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: 3. Analyzes its Data Object References (DORs) 4. Creates/updates .rels files with proper SOURCE and DESTINATION relationships + Args: + clean_first: If True, remove all existing .rels files before rebuilding + + Returns: + Dictionary with statistics: + - 'objects_processed': Number of objects analyzed + - 'rels_files_created': Number of .rels files created + - 'source_relationships': Number of SOURCE relationships created + - 'destination_relationships': Number of DESTINATION relationships created + - 'parallel_mode': True if parallel processing was used (optional key) + - 'execution_time': Execution time in seconds (optional key) + """ + if self.enable_parallel_rels: + return self._rebuild_all_rels_parallel(clean_first) + else: + return self._rebuild_all_rels_sequential(clean_first) + + def _rebuild_all_rels_sequential(self, clean_first: bool = True) -> Dict[str, int]: + """ + Rebuild all .rels files from scratch by analyzing all objects and their references. + + This method: + 1. Optionally cleans existing .rels files first + 2. Loads each object temporarily + 3. 
Analyzes its Data Object References (DORs) + 4. Creates/updates .rels files with proper SOURCE and DESTINATION relationships + Args: clean_first: If True, remove all existing .rels files before rebuilding @@ -1598,7 +2883,7 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: continue # metadata = self._metadata[identifier] - obj_rels_path = gen_rels_path(obj, self.export_version) + obj_rels_path = self._gen_rels_path_from_identifier(identifier) # Get all DORs (objects this object references) dors = get_direct_dor_list(obj) @@ -1624,7 +2909,7 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: except Exception as e: logging.debug(f"Failed to create SOURCE relationship: {e}") - if relationships: + if relationships and obj_rels_path: if obj_rels_path not in rels_files: rels_files[obj_rels_path] = Relationships(relationship=[]) rels_files[obj_rels_path].relationship.extend(relationships) @@ -1635,12 +2920,14 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: # Add DESTINATION relationships for target_identifier, source_list in reverse_references.items(): try: - target_obj = self.get_object_by_identifier(target_identifier) - if target_obj is None: + if target_identifier not in self._metadata: continue target_metadata = self._metadata[target_identifier] - target_rels_path = gen_rels_path(target_obj, self.export_version) + target_rels_path = self._gen_rels_path_from_identifier(target_identifier) + + if not target_rels_path: + continue # Create DESTINATION relationships for each object that references this one for source_identifier, source_obj in source_list: @@ -1666,12 +2953,44 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: stats["rels_files_created"] = len(rels_files) + # Before writing, preserve EXTERNAL_RESOURCE and other non-SOURCE/DESTINATION relationships + # This includes rels files that may not be in rels_files yet + with self._get_zip_file() as zf: + # Check all existing .rels files + for filename in zf.namelist(): + if not filename.endswith(".rels"): + continue + + try: + rels_data = zf.read(filename) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + # Preserve non-SOURCE/DESTINATION relationships (e.g., EXTERNAL_RESOURCE) + preserved_rels = [ + r + for r in existing_rels_obj.relationship + if r.type_value + not in ( + EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + ) + ] + if preserved_rels: + if filename in rels_files: + # Add preserved relationships to existing entry + rels_files[filename].relationship = preserved_rels + rels_files[filename].relationship + else: + # Create new entry with only preserved relationships + rels_files[filename] = Relationships(relationship=preserved_rels) + except Exception as e: + logging.debug(f"Could not preserve existing rels from {filename}: {e}") + # Third pass: write the new EPC with updated .rels files with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: temp_path = temp_file.name try: - with zipfile.ZipFile(self.epc_file_path, "r") as source_zip: + with self._get_zip_file() as source_zip: with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zip: # Copy all non-.rels files for item in source_zip.infolist(): @@ -1686,6 +3005,7 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: # Replace original file shutil.move(temp_path, self.epc_file_path) + 
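Both the sequential rebuild above and the parallel variant that follows derive relationships from the same inversion: an object's outgoing Data Object References (DORs) become SOURCE entries in its own `.rels` file, while a reverse map from each referenced identifier back to its referrers yields the matching DESTINATION entries in the target's `.rels` file (resolved via `_gen_rels_path_from_identifier`). A schematic, self-contained sketch of that inversion using plain identifier strings (the real code builds `Relationship` objects from `energyml.opc.opc`):

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def build_rels_maps(
    dors_by_obj: Dict[str, List[str]]
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Return (source_map, destination_map) from each object's referenced identifiers.

    source_map[obj]         -> identifiers obj references (SOURCE rels in obj's .rels)
    destination_map[target] -> identifiers referencing target (DESTINATION rels in target's .rels)
    """
    source_map: Dict[str, List[str]] = {}
    destination_map: Dict[str, List[str]] = defaultdict(list)
    for obj_id, targets in dors_by_obj.items():
        source_map[obj_id] = list(targets)
        for target_id in targets:
            destination_map[target_id].append(obj_id)
    return source_map, dict(destination_map)


# e.g. a chain trset -> bfi -> bf gives bf a DESTINATION entry for bfi, and bfi one for trset
src, dst = build_rels_maps({"trset.1": ["bfi.1"], "bfi.1": ["bf.1"], "bf.1": []})
assert dst == {"bfi.1": ["trset.1"], "bf.1": ["bfi.1"]}
```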
self._reopen_persistent_zip() logging.info( f"Rebuilt .rels files: processed {stats['objects_processed']} objects, " @@ -1702,6 +3022,218 @@ def rebuild_all_rels(self, clean_first: bool = True) -> Dict[str, int]: os.unlink(temp_path) raise RuntimeError(f"Failed to rebuild .rels files: {e}") + def _rebuild_all_rels_parallel(self, clean_first: bool = True) -> Dict[str, int]: + """ + Parallel implementation of rebuild_all_rels using multiprocessing. + + Strategy: + 1. Use multiprocessing.Pool to process objects in parallel + 2. Each worker loads an object and computes its SOURCE relationships + 3. Main process aggregates results and builds DESTINATION relationships + 4. Sequential write phase (ZIP writing must be sequential) + + This bypasses Python's GIL for CPU-intensive XML parsing and provides + significant speedup for large EPCs (tested with 80+ objects). + """ + import tempfile + import shutil + import time + from multiprocessing import Pool, cpu_count + + start_time = time.time() + + stats = { + "objects_processed": 0, + "rels_files_created": 0, + "source_relationships": 0, + "destination_relationships": 0, + "parallel_mode": True, + } + + num_objects = len(self._metadata) + logging.info(f"Starting PARALLEL rebuild of all .rels files for {num_objects} objects...") + + # Prepare work items for parallel processing + # Pass metadata as dict (serializable) instead of keeping references + metadata_dict = {k: v for k, v in self._metadata.items()} + work_items = [(identifier, str(self.epc_file_path), metadata_dict) for identifier in self._metadata] + + # Determine optimal number of workers based on available CPUs and workload + # Don't spawn more workers than CPUs; use user-configurable ratio for workload per worker + worker_ratio = self.parallel_worker_ratio if hasattr(self, "parallel_worker_ratio") else _WORKER_POOL_SIZE_RATIO + num_workers = min(cpu_count(), max(1, num_objects // worker_ratio)) + logging.info(f"Using {num_workers} worker processes for {num_objects} objects (ratio: {worker_ratio})") + + # ============================================================================ + # PHASE 1: PARALLEL - Compute SOURCE relationships across worker processes + # ============================================================================ + results = [] + with Pool(processes=num_workers) as pool: + results = pool.map(_process_object_for_rels_worker, work_items) + + # ============================================================================ + # PHASE 2: SEQUENTIAL - Aggregate worker results + # ============================================================================ + # Build data structures for subsequent phases: + # - reverse_references: Map target objects to their sources (for DESTINATION rels) + # - rels_files: Accumulate all relationships by file path + # - object_types: Cache object types to eliminate redundant loads in Phase 3 + reverse_references: Dict[str, List[Tuple[str, str]]] = {} + rels_files: Dict[str, Relationships] = {} + object_types: Dict[str, str] = {} + + for result in results: + if result is None: + continue + + identifier = result["identifier"] + obj_type = result["object_type"] + source_rels = result["source_rels"] + dor_targets = result["dor_targets"] + + # Cache object type + object_types[identifier] = obj_type + + stats["objects_processed"] += 1 + + # Convert dicts back to Relationship objects + if source_rels: + obj_rels_path = self._gen_rels_path_from_identifier(identifier) + if obj_rels_path: + relationships = [] + for rel_dict in source_rels: + rel = 
Relationship( + target=rel_dict["target"], + type_value=rel_dict["type_value"], + id=rel_dict["id"], + ) + relationships.append(rel) + stats["source_relationships"] += 1 + + if obj_rels_path not in rels_files: + rels_files[obj_rels_path] = Relationships(relationship=[]) + rels_files[obj_rels_path].relationship.extend(relationships) + + # Build reverse reference map for DESTINATION relationships + # dor_targets now contains (target_id, target_type) tuples + for target_identifier, target_type in dor_targets: + if target_identifier not in reverse_references: + reverse_references[target_identifier] = [] + reverse_references[target_identifier].append((identifier, obj_type)) + + # ============================================================================ + # PHASE 3: SEQUENTIAL - Create DESTINATION relationships (zero object loading!) + # ============================================================================ + # Use cached object types from Phase 2 to build DESTINATION relationships + # without reloading any objects. This optimization is critical for performance. + for target_identifier, source_list in reverse_references.items(): + try: + if target_identifier not in self._metadata: + continue + + target_rels_path = self._gen_rels_path_from_identifier(target_identifier) + + if not target_rels_path: + continue + + # Use cached object types instead of loading objects! + for source_identifier, source_type in source_list: + try: + source_metadata = self._metadata[source_identifier] + + # No object loading needed - we have all the type info from Phase 2! + rel = Relationship( + target=source_metadata.file_path, + type_value=EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + id=f"_{target_identifier}_{source_type}_{source_identifier}", + ) + + if target_rels_path not in rels_files: + rels_files[target_rels_path] = Relationships(relationship=[]) + rels_files[target_rels_path].relationship.append(rel) + stats["destination_relationships"] += 1 + + except Exception as e: + logging.debug(f"Failed to create DESTINATION relationship: {e}") + + except Exception as e: + logging.warning(f"Failed to create DESTINATION rels for {target_identifier}: {e}") + + stats["rels_files_created"] = len(rels_files) + + # ============================================================================ + # PHASE 4: SEQUENTIAL - Preserve non-object relationships + # ============================================================================ + # Preserve EXTERNAL_RESOURCE and other non-standard relationship types + with self._get_zip_file() as zf: + for filename in zf.namelist(): + if not filename.endswith(".rels"): + continue + + try: + rels_data = zf.read(filename) + existing_rels_obj = read_energyml_xml_bytes(rels_data, Relationships) + if existing_rels_obj and existing_rels_obj.relationship: + preserved_rels = [ + r + for r in existing_rels_obj.relationship + if r.type_value + not in ( + EPCRelsRelationshipType.SOURCE_OBJECT.get_type(), + EPCRelsRelationshipType.DESTINATION_OBJECT.get_type(), + ) + ] + if preserved_rels: + if filename in rels_files: + rels_files[filename].relationship = preserved_rels + rels_files[filename].relationship + else: + rels_files[filename] = Relationships(relationship=preserved_rels) + except Exception as e: + logging.debug(f"Could not preserve existing rels from {filename}: {e}") + + # ============================================================================ + # PHASE 5: SEQUENTIAL - Write all relationships to ZIP file + # 
============================================================================ + # ZIP file writing must be sequential (file format limitation) + with tempfile.NamedTemporaryFile(delete=False, suffix=".epc") as temp_file: + temp_path = temp_file.name + + try: + with self._get_zip_file() as source_zip: + with zipfile.ZipFile(temp_path, "w", zipfile.ZIP_DEFLATED) as target_zip: + # Copy all non-.rels files + for item in source_zip.infolist(): + if not (item.filename.endswith(".rels") and clean_first): + data = source_zip.read(item.filename) + target_zip.writestr(item, data) + + # Write new .rels files + for rels_path, rels_obj in rels_files.items(): + rels_xml = serialize_xml(rels_obj) + target_zip.writestr(rels_path, rels_xml) + + # Replace original file + shutil.move(temp_path, self.epc_file_path) + self._reopen_persistent_zip() + + execution_time = time.time() - start_time + stats["execution_time"] = execution_time + + logging.info( + f"Rebuilt .rels files (PARALLEL): processed {stats['objects_processed']} objects, " + f"created {stats['rels_files_created']} .rels files, " + f"added {stats['source_relationships']} SOURCE and " + f"{stats['destination_relationships']} DESTINATION relationships " + f"in {execution_time:.2f}s using {num_workers} workers" + ) + + return stats + + except Exception as e: + if os.path.exists(temp_path): + os.unlink(temp_path) + raise RuntimeError(f"Failed to rebuild .rels files (parallel): {e}") + def __repr__(self) -> str: """String representation.""" return ( @@ -1716,7 +3248,7 @@ def dumps_epc_content_and_files_lists(self): content_list = [] file_list = [] - with zipfile.ZipFile(self.epc_file_path, "r") as zf: + with self._get_zip_file() as zf: file_list = zf.namelist() for item in zf.infolist(): diff --git a/energyml-utils/src/energyml/utils/exception.py b/energyml-utils/src/energyml/utils/exception.py index 87e128c..fac041f 100644 --- a/energyml-utils/src/energyml/utils/exception.py +++ b/energyml-utils/src/energyml/utils/exception.py @@ -39,3 +39,10 @@ def __init__(self, t: Optional[str] = None): class UnparsableFile(Exception): def __init__(self, t: Optional[str] = None): super().__init__("File is not parsable for an EPC file. Please use RawFile class for non energyml files.") + + +class NotSupportedError(Exception): + """Exception for not supported features""" + + def __init__(self, msg): + super().__init__(msg) diff --git a/energyml-utils/src/energyml/utils/introspection.py b/energyml-utils/src/energyml/utils/introspection.py index e764eba..00408aa 100644 --- a/energyml-utils/src/energyml/utils/introspection.py +++ b/energyml-utils/src/energyml/utils/introspection.py @@ -233,6 +233,8 @@ def get_module_name_and_type_from_content_or_qualified_type(cqt: str) -> Tuple[s ct = parse_qualified_type(cqt) except AttributeError: pass + if ct is None: + raise ValueError(f"Cannot parse content-type or qualified-type: {cqt}") domain = ct.group("domain") if domain is None: @@ -281,6 +283,10 @@ def get_module_name(domain: str, domain_version: str): return f"energyml.{domain}.{domain_version}.{ns[ns.rindex('/') + 1:]}" +# Track modules that failed to import to avoid duplicate logging +_FAILED_IMPORT_MODULES = set() + + def import_related_module(energyml_module_name: str) -> None: """ Import related modules for a specific energyml module. (See. 
:const:`RELATED_MODULES`) @@ -292,8 +298,11 @@ def import_related_module(energyml_module_name: str) -> None: for m in related: try: import_module(m) - except Exception: - pass + except Exception as e: + # Only log once per unique module + if m not in _FAILED_IMPORT_MODULES: + _FAILED_IMPORT_MODULES.add(m) + logging.debug(f"Could not import related module {m}: {e}") # logging.error(e) @@ -425,6 +434,10 @@ def get_object_attribute(obj: Any, attr_dot_path: str, force_snake_case=True) -> """ current_attrib_name, path_next = path_next_attribute(attr_dot_path) + if current_attrib_name is None: + logging.error(f"Attribute path '{attr_dot_path}' is invalid.") + return None + if force_snake_case: current_attrib_name = snake_case(current_attrib_name) @@ -517,6 +530,10 @@ def get_object_attribute_or_create( """ current_attrib_name, path_next = path_next_attribute(attr_dot_path) + if current_attrib_name is None: + logging.error(f"Attribute path '{attr_dot_path}' is invalid.") + return None + if force_snake_case: current_attrib_name = snake_case(current_attrib_name) @@ -552,6 +569,10 @@ def get_object_attribute_advanced(obj: Any, attr_dot_path: str) -> Any: current_attrib_name = get_matching_class_attribute_name(obj, current_attrib_name) + if current_attrib_name is None: + logging.error(f"Attribute path '{attr_dot_path}' is invalid.") + return None + value = None if isinstance(obj, list): value = obj[int(current_attrib_name)] @@ -587,9 +608,10 @@ def get_object_attribute_no_verif(obj: Any, attr_name: str, default: Optional[An else: raise AttributeError(obj, name=attr_name) else: - return ( - getattr(obj, attr_name) or default - ) # we did not used the "default" of getattr to keep raising AttributeError + res = getattr(obj, attr_name) + if res is None: # we did not used the "default" of getattr to keep raising AttributeError + return default + return res def get_object_attribute_rgx(obj: Any, attr_dot_path_rgx: str) -> Any: @@ -870,6 +892,9 @@ def search_attribute_matching_name_with_path( # current_match = attrib_list[0] # next_match = ".".join(attrib_list[1:]) current_match, next_match = path_next_attribute(name_rgx) + if current_match is None: + logging.error(f"Attribute name regex '{name_rgx}' is invalid.") + return [] res = [] if current_path is None: @@ -997,7 +1022,7 @@ def set_attribute_from_dict(obj: Any, values: Dict) -> None: set_attribute_from_path(obj=obj, attribute_path=k, value=v) -def set_attribute_from_path(obj: Any, attribute_path: str, value: Any): +def set_attribute_from_path(obj: Any, attribute_path: str, value: Any) -> None: """ Changes the value of a (sub)attribute. 
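The change to `get_object_attribute_no_verif` a few hunks above fixes a classic falsy-value pitfall: `getattr(obj, name) or default` silently replaces legitimate values such as `0`, `""`, `False`, or an empty list with the default, whereas the explicit `is None` test only falls back when the attribute is genuinely unset. A tiny standalone illustration of the difference (toy class, not library code):

```python
# Toy example: why "getattr(obj, name) or default" is wrong for falsy values.
class Cell:
    count = 0  # a legitimate, falsy attribute value


default = 42
old_style = getattr(Cell, "count") or default          # -> 42 (wrong: count exists, just falsy)
res = getattr(Cell, "count")
new_style = default if res is None else res            # -> 0 (correct: only None falls back)
assert (old_style, new_style) == (42, 0)
```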
Example : @@ -1023,6 +1048,11 @@ def set_attribute_from_path(obj: Any, attribute_path: str, value: Any): """ upper = obj current_attrib_name, path_next = path_next_attribute(attribute_path) + + if current_attrib_name is None: + logging.error(f"Attribute path '{attribute_path}' is invalid.") + return + if path_next is not None: set_attribute_from_path( get_object_attribute( @@ -1066,12 +1096,12 @@ def set_attribute_from_path(obj: Any, attribute_path: str, value: Any): setattr(upper, current_attrib_name, value) -def set_attribute_value(obj: any, attribute_name_rgx, value: Any): +def set_attribute_value(obj: any, attribute_name_rgx, value: Any) -> None: copy_attributes(obj_in={attribute_name_rgx: value}, obj_out=obj, ignore_case=True) def copy_attributes( - obj_in: any, + obj_in: Any, obj_out: Any, only_existing_attributes: bool = True, ignore_case: bool = True, @@ -1081,7 +1111,7 @@ def copy_attributes( p_list = search_attribute_matching_name_with_path( obj=obj_out, name_rgx=k_in, - re_flags=re.IGNORECASE if ignore_case else 0, + re_flags=re.IGNORECASE if ignore_case else 0, # re.NOFLAG only available in Python 3.11+ deep_search=False, search_in_sub_obj=False, ) @@ -1337,7 +1367,7 @@ def get_qualified_type_from_class(cls: Union[type, Any], print_dev_version=True) return None -def get_object_uri(obj: any, dataspace: Optional[str] = None) -> Optional[Uri]: +def get_object_uri(obj: Any, dataspace: Optional[str] = None) -> Optional[Uri]: """Returns an ETP URI""" return parse_uri(f"eml:///dataspace('{dataspace or ''}')/{get_qualified_type_from_class(obj)}({get_obj_uuid(obj)})") @@ -1522,6 +1552,12 @@ def _gen_str_from_attribute_name(attribute_name: Optional[str], _parent_class: O :param _parent_class: :return: """ + if attribute_name is None: + return ( + "A random str (" + + str(random_value_from_class(int)) + + ") @_gen_str_from_attribute_name attribute 'attribute_name' was None" + ) attribute_name_lw = attribute_name.lower() if attribute_name is not None: if attribute_name_lw == "uuid" or attribute_name_lw == "uid": diff --git a/energyml-utils/src/energyml/utils/manager.py b/energyml-utils/src/energyml/utils/manager.py index 23933b3..10644ad 100644 --- a/energyml-utils/src/energyml/utils/manager.py +++ b/energyml-utils/src/energyml/utils/manager.py @@ -179,7 +179,7 @@ def get_class_pkg(cls): try: p = re.compile(RGX_ENERGYML_MODULE_NAME) match = p.search(cls.__module__) - return match.group("pkg") + return match.group("pkg") # type: ignore except AttributeError as e: logging.error(f"Exception to get class package for '{cls}'") raise e @@ -217,6 +217,8 @@ def reshape_version_from_regex_match( :param nb_digit: The number of digits to keep in the version. :return: The reshaped version string. """ + if match is None: + return "" return reshape_version(match.group("versionNumber"), nb_digit) + ( "dev" + match.group("versionDev") if match.group("versionDev") is not None and print_dev_version else "" ) diff --git a/energyml-utils/src/energyml/utils/storage_interface.py b/energyml-utils/src/energyml/utils/storage_interface.py new file mode 100644 index 0000000..99a58d1 --- /dev/null +++ b/energyml-utils/src/energyml/utils/storage_interface.py @@ -0,0 +1,375 @@ +# Copyright (c) 2023-2024 Geosiris. +# SPDX-License-Identifier: Apache-2.0 +""" +Unified Storage Interface Module + +This module provides a unified interface for reading and writing energyml objects and arrays, +abstracting away whether the data comes from an ETP server, a local EPC file, or an EPC stream reader. 
+ +The storage interface enables applications to work with energyml data without knowing the +underlying storage mechanism, making it easy to switch between server-based and file-based +workflows. + +Key Components: +- EnergymlStorageInterface: Abstract base class defining the storage interface +- ResourceMetadata: Dataclass for object metadata (similar to ETP Resource) +- DataArrayMetadata: Dataclass for array metadata + +Example Usage: + ```python + from energyml.utils.storage_interface import create_storage + + # Use with EPC file + storage = create_storage("my_data.epc") + + # Same API for all implementations! + obj = storage.get_object("uuid.version") or storage.get_object("eml:///dataspace('default')/resqml22.TriangulatedSetRepresentation('uuid')") + metadata_list = storage.list_objects() + array = storage.read_array(obj, "values/0") + storage.put_object(new_obj) + storage.close() + ``` +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional, Union, Tuple + +from energyml.utils.uri import Uri +from energyml.opc.opc import Relationship +import numpy as np + + +@dataclass +class ResourceMetadata: + """ + Metadata for an energyml object, similar to ETP Resource. + + This class provides a unified representation of object metadata across + different storage backends (EPC, EPC Stream, ETP). + """ + + uri: str + """URI of the resource (ETP-style uri or identifier)""" + + uuid: str + """Object UUID""" + + title: str + """Object title/name from citation""" + + object_type: str + """Qualified type (e.g., 'resqml20.obj_TriangulatedSetRepresentation')""" + + content_type: str + """Content type (e.g., 'application/x-resqml+xml;version=2.0;type=obj_TriangulatedSetRepresentation')""" + + version: Optional[str] = None + """Object version (optional)""" + + dataspace: Optional[str] = None + """Dataspace name (primarily for ETP)""" + + created: Optional[datetime] = None + """Creation timestamp""" + + last_changed: Optional[datetime] = None + """Last modification timestamp""" + + source_count: Optional[int] = None + """Number of source relationships (objects this references)""" + + target_count: Optional[int] = None + """Number of target relationships (objects referencing this)""" + + custom_data: Dict[str, Any] = field(default_factory=dict) + """Additional custom metadata""" + + @property + def identifier(self) -> str: + """Get object identifier (uuid.version or uuid if no version)""" + if self.version: + return f"{self.uuid}.{self.version}" + return self.uuid + + +@dataclass +class DataArrayMetadata: + """ + Metadata for a data array in an energyml object. + + This provides information about arrays stored in HDF5 or other external storage, + similar to ETP DataArrayMetadata. 
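The `identifier` property on `ResourceMetadata` above mirrors the convention used throughout the stream reader: `uuid.version` when a version is present, bare `uuid` otherwise. A small illustration of that behaviour; all field values below are made up and only follow the formats described in the docstrings:

```python
from energyml.utils.storage_interface import ResourceMetadata

meta = ResourceMetadata(
    uri="eml:///dataspace('default')/resqml22.TriangulatedSetRepresentation(a1b2c3d4-0000-4000-8000-000000000001)",
    uuid="a1b2c3d4-0000-4000-8000-000000000001",
    title="Demo surface",
    object_type="resqml22.TriangulatedSetRepresentation",
    content_type="application/x-resqml+xml;version=2.2;type=TriangulatedSetRepresentation",
    version="1.0",
)
assert meta.identifier == "a1b2c3d4-0000-4000-8000-000000000001.1.0"

# Without a version, the identifier degrades to the bare UUID.
meta_no_version = ResourceMetadata(
    uri=meta.uri,
    uuid=meta.uuid,
    title=meta.title,
    object_type=meta.object_type,
    content_type=meta.content_type,
)
assert meta_no_version.identifier == meta.uuid
```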
+ """ + + path_in_resource: Optional[str] + """Path to the array within the HDF5 file""" + + array_type: str + """Data type of the array (e.g., 'double', 'int', 'string')""" + + dimensions: List[int] + """Array dimensions/shape""" + + custom_data: Dict[str, Any] = field(default_factory=dict) + """Additional custom metadata""" + + @property + def size(self) -> int: + """Total number of elements in the array""" + result = 1 + for dim in self.dimensions: + result *= dim + return result + + @property + def ndim(self) -> int: + """Number of dimensions""" + return len(self.dimensions) + + @classmethod + def from_numpy_array(cls, path_in_resource: Optional[str], array: np.ndarray) -> "DataArrayMetadata": + """ + Create DataArrayMetadata from a numpy array. + + Args: + path_in_resource: Path to the array within the HDF5 file + array: Numpy array + Returns: + DataArrayMetadata instance + """ + return cls( + path_in_resource=path_in_resource, + array_type=str(array.dtype), + dimensions=list(array.shape), + ) + + @classmethod + def from_list(cls, path_in_resource: Optional[str], data: List[Any]) -> "DataArrayMetadata": + """ + Create DataArrayMetadata from a list. + + Args: + path_in_resource: Path to the array within the HDF5 file + data: List of data + Returns: + DataArrayMetadata instance + """ + array = np.array(data) + return cls.from_numpy_array(path_in_resource, array) + + +class EnergymlStorageInterface(ABC): + """ + Abstract base class for energyml data storage operations. + + This interface defines a common API for interacting with energyml objects and arrays, + regardless of whether they are stored on an ETP server, in a local EPC file, or in + a streaming EPC reader. + + All implementations must provide methods for: + - Getting, putting, and deleting energyml objects + - Reading and writing data arrays + - Getting array metadata + - Listing available objects with metadata + - Transaction support (where applicable) + - Closing the storage connection + """ + + @abstractmethod + def get_object(self, identifier: Union[str, Uri]) -> Optional[Any]: + """ + Retrieve an object by its identifier (UUID or UUID.version). + + Args: + identifier: Object identifier (UUID or UUID.version) or ETP URI + + Returns: + The deserialized energyml object, or None if not found + """ + pass + + @abstractmethod + def get_object_by_uuid(self, uuid: str) -> List[Any]: + """ + Retrieve all objects with the given UUID (all versions). + + Args: + uuid: Object UUID + + Returns: + List of objects with this UUID (may be empty) + """ + pass + + @abstractmethod + def put_object(self, obj: Any, dataspace: Optional[str] = None) -> Optional[str]: + """ + Store an energyml object. + + Args: + obj: The energyml object to store + dataspace: Optional dataspace name (primarily for ETP) + + Returns: + The identifier of the stored object (UUID.version or UUID), or None on error + """ + pass + + @abstractmethod + def delete_object(self, identifier: Union[str, Uri]) -> bool: + """ + Delete an object by its identifier. + + Args: + identifier: Object identifier (UUID or UUID.version) or ETP URI + + Returns: + True if successfully deleted, False otherwise + """ + pass + + @abstractmethod + def read_array(self, proxy: Union[str, Uri, Any], path_in_external: str) -> Optional[np.ndarray]: + """ + Read a data array from external storage (HDF5). 
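The `DataArrayMetadata` class methods above derive everything from the numpy array itself, and `size`/`ndim` are then computed from `dimensions`. A quick check of that behaviour; the HDF5 path is a made-up placeholder:

```python
import numpy as np

from energyml.utils.storage_interface import DataArrayMetadata

arr = np.zeros((3, 4), dtype=np.float64)
meta = DataArrayMetadata.from_numpy_array("/RESQML/demo/points_patch0", arr)  # placeholder path

assert meta.array_type == "float64"      # str(array.dtype)
assert meta.dimensions == [3, 4]
assert meta.size == 12 and meta.ndim == 2

# from_list simply wraps the data in a numpy array first.
meta2 = DataArrayMetadata.from_list(None, [[1, 2], [3, 4], [5, 6]])
assert meta2.dimensions == [3, 2]
```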
+ + Args: + proxy: The object identifier/URI or the object itself that references the array + path_in_external: Path within the HDF5 file (e.g., 'values/0') + + Returns: + The data array as a numpy array, or None if not found + """ + pass + + @abstractmethod + def write_array( + self, + proxy: Union[str, Uri, Any], + path_in_external: str, + array: np.ndarray, + ) -> bool: + """ + Write a data array to external storage (HDF5). + + Args: + proxy: The object identifier/URI or the object itself that references the array + path_in_external: Path within the HDF5 file (e.g., 'values/0') + array: The numpy array to write + + Returns: + True if successfully written, False otherwise + """ + pass + + @abstractmethod + def get_array_metadata( + self, proxy: Union[str, Uri, Any], path_in_external: Optional[str] = None + ) -> Union[DataArrayMetadata, List[DataArrayMetadata], None]: + """ + Get metadata for data array(s). + + Args: + proxy: The object identifier/URI or the object itself that references the array + path_in_external: Optional specific path. If None, returns all array metadata for the object + + Returns: + DataArrayMetadata if path specified, List[DataArrayMetadata] if no path, + or None if not found + """ + pass + + @abstractmethod + def list_objects( + self, dataspace: Optional[str] = None, object_type: Optional[str] = None + ) -> List[ResourceMetadata]: + """ + List all objects with their metadata. + + Args: + dataspace: Optional dataspace filter (primarily for ETP) + object_type: Optional type filter (qualified type, e.g., 'resqml20.obj_Grid2dRepresentation') + + Returns: + List of ResourceMetadata for all matching objects + """ + pass + + @abstractmethod + def get_obj_rels(self, obj: Union[str, Uri, Any]) -> List[Relationship]: + """Get relationships for an object. + + Args: + obj: The object identifier/URI or the object itself + + Returns: + List of Relationship objects + """ + pass + + @abstractmethod + def close(self) -> None: + """ + Close the storage connection and release resources. + """ + pass + + # Transaction support (optional, may raise NotImplementedError) + + def start_transaction(self) -> bool: + """ + Start a transaction (if supported). + + Returns: + True if transaction started, False if not supported + """ + raise NotImplementedError("Transactions not supported by this storage backend") + + def commit_transaction(self) -> Tuple[bool, Optional[str]]: + """ + Commit the current transaction (if supported). + + Returns: + Tuple of (success, transaction_uuid) + """ + raise NotImplementedError("Transactions not supported by this storage backend") + + def rollback_transaction(self) -> bool: + """ + Rollback the current transaction (if supported). + + Returns: + True if rolled back successfully + """ + raise NotImplementedError("Transactions not supported by this storage backend") + + # Additional utility methods + + def get_object_dependencies(self, identifier: Union[str, Uri]) -> List[str]: + """ + Get list of object identifiers that this object depends on (references). 
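Transactions and dependency tracking are optional in this interface: the base-class methods raise `NotImplementedError`, and `start_transaction` may also return `False` on backends that implement it without support. A hedged caller-side sketch (hypothetical helper, not part of the module) that stores objects against any backend, wrapping them in a transaction only when one is available:

```python
from typing import Any, Iterable

from energyml.utils.storage_interface import EnergymlStorageInterface


def put_all(storage: EnergymlStorageInterface, objs: Iterable[Any]) -> None:
    """Store objects, using a transaction only if the backend supports one."""
    in_txn = False
    try:
        in_txn = storage.start_transaction()
    except NotImplementedError:
        pass  # e.g. file-based backends may not implement transactions

    try:
        for obj in objs:
            storage.put_object(obj)
        if in_txn:
            storage.commit_transaction()
    except Exception:
        if in_txn:
            storage.rollback_transaction()
        raise
```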
+ + Args: + identifier: Object identifier + + Returns: + List of identifiers of objects this object references + """ + raise NotImplementedError("Dependency tracking not implemented by this storage backend") + + def __enter__(self): + """Context manager entry""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit""" + self.close() + + +__all__ = [ + "EnergymlStorageInterface", + "ResourceMetadata", + "DataArrayMetadata", +] diff --git a/energyml-utils/src/energyml/utils/workspace.py b/energyml-utils/src/energyml/utils/workspace.py deleted file mode 100644 index 8371644..0000000 --- a/energyml-utils/src/energyml/utils/workspace.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023-2024 Geosiris. -# SPDX-License-Identifier: Apache-2.0 -from abc import abstractmethod -from dataclasses import dataclass -from typing import Optional, Any, Union - -from energyml.utils.uri import Uri -import numpy as np - - -@dataclass -class EnergymlWorkspace: - def get_object(self, uuid: str, object_version: Optional[str]) -> Optional[Any]: - raise NotImplementedError("EnergymlWorkspace.get_object") - - def get_object_by_identifier(self, identifier: str) -> Optional[Any]: - _tmp = identifier.split(".") - return self.get_object(_tmp[0], _tmp[1] if len(_tmp) > 1 else None) - - def get_object_by_uuid(self, uuid: str) -> Optional[Any]: - return self.get_object(uuid, None) - - # def read_external_array( - # self, - # energyml_array: Any, - # root_obj: Optional[Any] = None, - # path_in_root: Optional[str] = None, - # ) -> List[Any]: - # raise NotImplementedError("EnergymlWorkspace.get_object") - - @abstractmethod - def add_object(self, obj: Any) -> bool: - raise NotImplementedError("EnergymlWorkspace.add_object") - - @abstractmethod - def remove_object(self, identifier: Union[str, Uri]) -> None: - raise NotImplementedError("EnergymlWorkspace.remove_object") - - @abstractmethod - def read_array(self, proxy: Union[str, Uri, Any], path_in_external: str) -> Optional[np.ndarray]: - raise NotImplementedError("EnergymlWorkspace.read_array") - - @abstractmethod - def write_array(self, proxy: Union[str, Uri, Any], path_in_external: str, array: Any) -> bool: - raise NotImplementedError("EnergymlWorkspace.write_array") diff --git a/energyml-utils/tests/test_epc.py b/energyml-utils/tests/test_epc.py index 11626a8..de6ea53 100644 --- a/energyml-utils/tests/test_epc.py +++ b/energyml-utils/tests/test_epc.py @@ -9,13 +9,13 @@ from energyml.resqml.v2_0_1.resqmlv2 import FaultInterpretation from energyml.resqml.v2_2.resqmlv2 import TriangulatedSetRepresentation -from src.energyml.utils.epc import ( +from energyml.utils.epc import ( as_dor, get_obj_identifier, gen_energyml_object_path, EpcExportVersion, ) -from src.energyml.utils.introspection import ( +from energyml.utils.introspection import ( epoch_to_date, epoch, gen_uuid, diff --git a/energyml-utils/tests/test_epc_stream.py b/energyml-utils/tests/test_epc_stream.py new file mode 100644 index 0000000..f22824c --- /dev/null +++ b/energyml-utils/tests/test_epc_stream.py @@ -0,0 +1,934 @@ +# Copyright (c) 2023-2024 Geosiris. +# SPDX-License-Identifier: Apache-2.0 +""" +Comprehensive unit tests for EpcStreamReader functionality. + +Tests cover: +1. Relationship update modes (UPDATE_AT_MODIFICATION, UPDATE_ON_CLOSE, MANUAL) +2. Object lifecycle (add, update, remove) +3. Relationship consistency +4. Performance and caching +5. 
Edge cases and error handling +""" +import os +import tempfile +import zipfile +from pathlib import Path + +import pytest +import numpy as np + +from energyml.eml.v2_3.commonv2 import Citation, DataObjectReference +from energyml.resqml.v2_2.resqmlv2 import ( + TriangulatedSetRepresentation, + BoundaryFeatureInterpretation, + BoundaryFeature, + HorizonInterpretation, +) +from energyml.opc.opc import Relationships + +from energyml.utils.epc_stream import EpcStreamReader, RelsUpdateMode +from energyml.utils.epc import create_energyml_object, as_dor, get_obj_identifier +from energyml.utils.introspection import ( + epoch_to_date, + epoch, + gen_uuid, + get_direct_dor_list, +) +from energyml.utils.constants import EPCRelsRelationshipType +from energyml.utils.serialization import read_energyml_xml_bytes + + +@pytest.fixture +def temp_epc_file(): + """Create a temporary EPC file path for testing.""" + # Create temp file path but don't create the file itself + # Let EpcStreamReader create it + fd, temp_path = tempfile.mkstemp(suffix=".epc") + os.close(fd) # Close the file descriptor + os.unlink(temp_path) # Remove the empty file + + yield temp_path + + # Cleanup + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.fixture +def sample_objects(): + """Create sample EnergyML objects for testing.""" + # Create a BoundaryFeature + bf = BoundaryFeature( + citation=Citation( + title="Test Boundary Feature", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=gen_uuid(), + object_version="1.0", + ) + + # Create a BoundaryFeatureInterpretation + bfi = BoundaryFeatureInterpretation( + citation=Citation( + title="Test Boundary Feature Interpretation", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=gen_uuid(), + object_version="1.0", + interpreted_feature=as_dor(bf), + ) + + # Create a TriangulatedSetRepresentation + trset = TriangulatedSetRepresentation( + citation=Citation( + title="Test TriangulatedSetRepresentation", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=gen_uuid(), + object_version="1.0", + represented_object=as_dor(bfi), + ) + + # Create a HorizonInterpretation (independent object) + horizon_interp = HorizonInterpretation( + citation=Citation( + title="Test HorizonInterpretation", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=gen_uuid(), + object_version="1.0", + domain="depth", + ) + + return { + "bf": bf, + "bfi": bfi, + "trset": trset, + "horizon_interp": horizon_interp, + } + + +class TestRelsUpdateModes: + """Test different relationship update modes.""" + + def test_manual_mode_no_auto_rebuild(self, temp_epc_file, sample_objects): + """Test that MANUAL mode does not automatically rebuild relationships on close.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.MANUAL) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + # Add objects in MANUAL mode + reader.add_object(bf) + reader.add_object(bfi) + + # Close without rebuild (MANUAL mode should not call rebuild_all_rels) + reader.close() + + # Reopen and check - rels should exist from _add_object_to_file + # but they won't be "rebuilt" from scratch + reader2 = EpcStreamReader(temp_epc_file) + + # Objects should be there + assert len(reader2) == 2 + + # Basic rels should exist (from _add_object_to_file) + bfi_rels = reader2.get_obj_rels(get_obj_identifier(bfi)) + assert len(bfi_rels) > 0 # Should have SOURCE rels + + reader2.close() + + def test_update_on_close_mode(self, temp_epc_file, sample_objects): + 
"""Test that UPDATE_ON_CLOSE mode rebuilds rels on close.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_ON_CLOSE) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + trset = sample_objects["trset"] + + # Add objects + reader.add_object(bf) + reader.add_object(bfi) + reader.add_object(trset) + + # Before closing, rels may not be complete + reader.close() + + # Reopen and verify relationships were built + reader2 = EpcStreamReader(temp_epc_file) + + # Check that bfi has a SOURCE relationship to bf + bfi_rels = reader2.get_obj_rels(get_obj_identifier(bfi)) + source_rels = [r for r in bfi_rels if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type()] + assert len(source_rels) >= 1, "Expected SOURCE relationship from bfi to bf" + + # Check that bf has a DESTINATION relationship from bfi + bf_rels = reader2.get_obj_rels(get_obj_identifier(bf)) + dest_rels = [r for r in bf_rels if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type()] + assert len(dest_rels) >= 1, "Expected DESTINATION relationship from bfi to bf" + + reader2.close() + + def test_update_at_modification_mode_add(self, temp_epc_file, sample_objects): + """Test that UPDATE_AT_MODIFICATION mode updates rels immediately on add.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + # Add objects + reader.add_object(bf) + reader.add_object(bfi) + + # Check relationships immediately (without closing) + bfi_rels = reader.get_obj_rels(get_obj_identifier(bfi)) + source_rels = [r for r in bfi_rels if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type()] + assert len(source_rels) >= 1, "Expected immediate SOURCE relationship from bfi to bf" + + bf_rels = reader.get_obj_rels(get_obj_identifier(bf)) + dest_rels = [r for r in bf_rels if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type()] + assert len(dest_rels) >= 1, "Expected immediate DESTINATION relationship from bfi to bf" + + reader.close() + + def test_update_at_modification_mode_remove(self, temp_epc_file, sample_objects): + """Test that UPDATE_AT_MODIFICATION mode cleans up rels on remove.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + # Add objects + reader.add_object(bf) + reader.add_object(bfi) + + # Verify relationships exist + bf_rels_before = reader.get_obj_rels(get_obj_identifier(bf)) + assert len(bf_rels_before) > 0, "Expected relationships before removal" + + # Remove bfi + reader.remove_object(get_obj_identifier(bfi)) + + # Check that bf's rels no longer has references to bfi + bf_rels_after = reader.get_obj_rels(get_obj_identifier(bf)) + bfi_refs = [r for r in bf_rels_after if get_obj_identifier(bfi) in r.id] + assert len(bfi_refs) == 0, "Expected no references to removed object" + + reader.close() + + def test_update_at_modification_mode_update(self, temp_epc_file, sample_objects): + """Test that UPDATE_AT_MODIFICATION mode updates rels on object modification.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + trset = sample_objects["trset"] + + # Add initial objects + reader.add_object(bf) + reader.add_object(bfi) + reader.add_object(trset) + + # Modify bfi to reference a different feature (create new one) + bf2 = BoundaryFeature( + 
citation=Citation( + title="Test Boundary Feature 2", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=gen_uuid(), + object_version="1.0", + ) + reader.add_object(bf2) + + # Update bfi to reference bf2 instead of bf + bfi_modified = BoundaryFeatureInterpretation( + citation=bfi.citation, + uuid=bfi.uuid, + object_version=bfi.object_version, + interpreted_feature=as_dor(bf2), + ) + + reader.update_object(bfi_modified) + + # Check that bf no longer has DESTINATION relationship from bfi + bf_rels = reader.get_obj_rels(get_obj_identifier(bf)) + bfi_dest_rels = [ + r + for r in bf_rels + if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type() and get_obj_identifier(bfi) in r.id + ] + assert len(bfi_dest_rels) == 0, "Expected old DESTINATION relationship to be removed" + + # Check that bf2 now has DESTINATION relationship from bfi + bf2_rels = reader.get_obj_rels(get_obj_identifier(bf2)) + bfi_dest_rels2 = [ + r + for r in bf2_rels + if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type() and get_obj_identifier(bfi) in r.id + ] + assert len(bfi_dest_rels2) >= 1, "Expected new DESTINATION relationship to be added" + + reader.close() + + +class TestObjectLifecycle: + """Test object lifecycle operations.""" + + def test_add_object(self, temp_epc_file, sample_objects): + """Test adding objects to EPC.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + assert identifier == get_obj_identifier(bf) + assert identifier in reader._metadata + assert reader.get_object_by_identifier(identifier) is not None + + reader.close() + + def test_remove_object(self, temp_epc_file, sample_objects): + """Test removing objects from EPC.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + result = reader.remove_object(identifier) + assert result is True + assert identifier not in reader._metadata + assert reader.get_object_by_identifier(identifier) is None + + reader.close() + + def test_update_object(self, temp_epc_file, sample_objects): + """Test updating existing objects.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + # Modify the object + bf_modified = BoundaryFeature( + citation=Citation( + title="Modified Title", + originator="Test", + creation=epoch_to_date(epoch()), + ), + uuid=bf.uuid, + object_version=bf.object_version, + ) + + new_identifier = reader.update_object(bf_modified) + assert new_identifier == identifier + + # Verify the object was updated + obj = reader.get_object_by_identifier(identifier) + assert obj.citation.title == "Modified Title" + + reader.close() + + def test_replace_if_exists(self, temp_epc_file, sample_objects): + """Test replace_if_exists parameter.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + # Try to add same object again with replace_if_exists=False + with pytest.raises((ValueError, RuntimeError)) as exc_info: + reader.add_object(bf, replace_if_exists=False) + # The error message should mention the object already exists + assert "already exists" in str(exc_info.value).lower() + + # Should work with replace_if_exists=True (default) + identifier2 = reader.add_object(bf, replace_if_exists=True) + assert identifier == identifier2 + + reader.close() + + +class TestRelationshipConsistency: + """Test relationship consistency and correctness.""" + + def 
test_bidirectional_relationships(self, temp_epc_file, sample_objects): + """Test that SOURCE and DESTINATION relationships are bidirectional.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + reader.add_object(bf) + reader.add_object(bfi) + + # Check bfi -> bf (SOURCE) + bfi_rels = reader.get_obj_rels(get_obj_identifier(bfi)) + bfi_source_to_bf = [ + r + for r in bfi_rels + if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type() and get_obj_identifier(bf) in r.id + ] + assert len(bfi_source_to_bf) >= 1 + + # Check bf -> bfi (DESTINATION) + bf_rels = reader.get_obj_rels(get_obj_identifier(bf)) + bf_dest_from_bfi = [ + r + for r in bf_rels + if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type() and get_obj_identifier(bfi) in r.id + ] + assert len(bf_dest_from_bfi) >= 1 + + reader.close() + + def test_cascade_relationships(self, temp_epc_file, sample_objects): + """Test relationships in a chain: trset -> bfi -> bf.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + trset = sample_objects["trset"] + + reader.add_object(bf) + reader.add_object(bfi) + reader.add_object(trset) + + # Check trset -> bfi + trset_rels = reader.get_obj_rels(get_obj_identifier(trset)) + trset_to_bfi = [ + r + for r in trset_rels + if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type() and get_obj_identifier(bfi) in r.id + ] + assert len(trset_to_bfi) >= 1 + + # Check bfi -> bf + bfi_rels = reader.get_obj_rels(get_obj_identifier(bfi)) + bfi_to_bf = [ + r + for r in bfi_rels + if r.type_value == EPCRelsRelationshipType.SOURCE_OBJECT.get_type() and get_obj_identifier(bf) in r.id + ] + assert len(bfi_to_bf) >= 1 + + # Check bf has 2 DESTINATION relationships (from bfi and indirectly from trset) + bf_rels = reader.get_obj_rels(get_obj_identifier(bf)) + bf_dest_rels = [r for r in bf_rels if r.type_value == EPCRelsRelationshipType.DESTINATION_OBJECT.get_type()] + assert len(bf_dest_rels) >= 1 + + reader.close() + + def test_independent_objects_no_rels(self, temp_epc_file, sample_objects): + """Test that independent objects don't have relationships between two boundary features.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Use two boundary features with no references to each other + bf1 = sample_objects["bf"] + bf2 = BoundaryFeature( + uuid="00000000-0000-0000-0000-000000000099", + citation=Citation(title="Second Boundary Feature", originator="Test", creation="2026-01-01T00:00:00Z"), + ) + + reader.add_object(bf1) + reader.add_object(bf2) + + # Check that bf2 has no relationships to bf1 + bf2_rels = reader.get_obj_rels(get_obj_identifier(bf2)) + bf1_refs = [r for r in bf2_rels if get_obj_identifier(bf1) in r.id] + assert len(bf1_refs) == 0 + + reader.close() + + +class TestCachingAndPerformance: + """Test caching functionality and performance optimizations.""" + + def test_cache_hit_rate(self, temp_epc_file, sample_objects): + """Test that cache is working properly.""" + reader = EpcStreamReader(temp_epc_file, cache_size=10) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + # First access - cache miss + obj1 = reader.get_object_by_identifier(identifier) + stats1 = reader.get_statistics() + + # Second access - cache hit + obj2 = reader.get_object_by_identifier(identifier) + stats2 
= reader.get_statistics() + + assert stats2.cache_hits >= stats1.cache_hits + assert obj1 is obj2 # Should be same object reference + + reader.close() + + def test_metadata_access_without_loading(self, temp_epc_file, sample_objects): + """Test that metadata can be accessed without loading full objects.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + reader.add_object(bf) + reader.add_object(bfi) + + reader.close() + + # Reopen and access metadata + reader2 = EpcStreamReader(temp_epc_file, preload_metadata=True) + + # Check that we can list objects without loading them + metadata_list = reader2.list_object_metadata() + assert len(metadata_list) == 2 + assert reader2.stats.loaded_objects == 0, "Expected no objects loaded when accessing metadata" + + reader2.close() + + def test_lazy_loading(self, temp_epc_file, sample_objects): + """Test that objects are loaded on-demand.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + trset = sample_objects["trset"] + + reader.add_object(bf) + reader.add_object(bfi) + reader.add_object(trset) + + reader.close() + + # Reopen + reader2 = EpcStreamReader(temp_epc_file) + assert len(reader2) == 3 + assert reader2.stats.loaded_objects == 0, "Expected no objects loaded initially" + + # Load one object + reader2.get_object_by_identifier(get_obj_identifier(bf)) + assert reader2.stats.loaded_objects == 1, "Expected exactly 1 object loaded" + + reader2.close() + + +class TestHelperMethods: + """Test helper methods for rels path generation.""" + + def test_gen_rels_path_from_metadata(self, temp_epc_file, sample_objects): + """Test generating rels path from metadata.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + metadata = reader._metadata[identifier] + rels_path = reader._gen_rels_path_from_metadata(metadata) + + assert rels_path is not None + assert "_rels/" in rels_path + assert ".rels" in rels_path + + reader.close() + + def test_gen_rels_path_from_identifier(self, temp_epc_file, sample_objects): + """Test generating rels path from identifier.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + identifier = reader.add_object(bf) + + rels_path = reader._gen_rels_path_from_identifier(identifier) + + assert rels_path is not None + assert "_rels/" in rels_path + assert ".rels" in rels_path + + reader.close() + + +class TestModeManagement: + """Test mode switching and management.""" + + def test_set_rels_update_mode(self, temp_epc_file): + """Test changing the relationship update mode.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.MANUAL) + + assert reader.get_rels_update_mode() == RelsUpdateMode.MANUAL + + reader.set_rels_update_mode(RelsUpdateMode.UPDATE_AT_MODIFICATION) + assert reader.get_rels_update_mode() == RelsUpdateMode.UPDATE_AT_MODIFICATION + + reader.close() + + def test_invalid_mode_raises_error(self, temp_epc_file): + """Test that invalid mode raises error.""" + reader = EpcStreamReader(temp_epc_file) + + with pytest.raises(ValueError): + reader.set_rels_update_mode("invalid_mode") + + reader.close() + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_remove_nonexistent_object(self, temp_epc_file): + """Test removing an object that doesn't exist.""" + reader = EpcStreamReader(temp_epc_file) + + result = reader.remove_object("nonexistent-uuid.0") + assert result is False + + 
reader.close() + + def test_update_nonexistent_object(self, temp_epc_file, sample_objects): + """Test updating an object that doesn't exist.""" + reader = EpcStreamReader(temp_epc_file) + + bf = sample_objects["bf"] + + with pytest.raises(ValueError): + reader.update_object(bf) + + reader.close() + + def test_empty_epc_operations(self, temp_epc_file): + """Test operations on empty EPC.""" + reader = EpcStreamReader(temp_epc_file) + + assert len(reader) == 0 + assert len(reader.list_object_metadata()) == 0 + + reader.close() + + def test_multiple_add_remove_cycles(self, temp_epc_file, sample_objects): + """Test multiple add/remove cycles.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + bf = sample_objects["bf"] + + for _ in range(3): + identifier = reader.add_object(bf) + assert identifier in reader._metadata + + reader.remove_object(identifier) + assert identifier not in reader._metadata + + reader.close() + + +class TestRebuildAllRels: + """Test the rebuild_all_rels functionality.""" + + def test_rebuild_all_rels_manual_mode(self, temp_epc_file, sample_objects): + """Test manually rebuilding relationships in MANUAL mode.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.MANUAL) + + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + + reader.add_object(bf) + reader.add_object(bfi) + + # Manually rebuild relationships + stats = reader.rebuild_all_rels(clean_first=True) + + assert stats["objects_processed"] == 2 + assert stats["source_relationships"] >= 1 + assert stats["destination_relationships"] >= 1 + + # Verify relationships exist now + bfi_rels = reader.get_obj_rels(get_obj_identifier(bfi)) + assert len(bfi_rels) > 0 + + reader.close() + + +class TestArrayOperations: + """Test HDF5 array operations.""" + + def test_write_read_array(self, temp_epc_file, sample_objects): + """Test writing and reading arrays.""" + # Create temp HDF5 file + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".h5") as f: + h5_path = f.name + + try: + reader = EpcStreamReader(temp_epc_file, force_h5_path=h5_path) + + trset = sample_objects["trset"] + reader.add_object(trset) + + # Write array + test_array = np.arange(12).reshape((3, 4)) + success = reader.write_array(trset, "/test_dataset", test_array) + assert success + + # Read array back + read_array = reader.read_array(trset, "/test_dataset") + assert read_array is not None + assert np.array_equal(read_array, test_array) + + # Close reader before deleting files + reader.close() + finally: + # Give time for file handles to be released + import time + + time.sleep(0.1) + if os.path.exists(h5_path): + try: + os.unlink(h5_path) + except PermissionError: + pass # File still locked, skip cleanup + + +class TestAdditionalRelsPreservation: + """Test that manually added relationships (like EXTERNAL_RESOURCE) are preserved during updates.""" + + def test_external_resource_preserved_on_object_update(self, temp_epc_file, sample_objects): + """Test that EXTERNAL_RESOURCE relationships are preserved when the object is updated.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Add initial object + trset = sample_objects["trset"] + identifier = reader.add_object(trset) + + # Add EXTERNAL_RESOURCE relationship manually + from energyml.opc.opc import Relationship + + h5_rel = Relationship( + target="data/test_data.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + 
id=f"_external_{identifier}_h5", + ) + reader.add_rels_for_object(identifier, [h5_rel], write_immediately=True) + + # Verify the HDF5 path is returned + h5_paths_before = reader.get_h5_file_paths(identifier) + assert "data/test_data.h5" in h5_paths_before + + # Update the object (modify its title) + trset.citation.title = "Updated Triangulated Set" + reader.update_object(trset) + + # Verify EXTERNAL_RESOURCE relationship is still present + h5_paths_after = reader.get_h5_file_paths(identifier) + assert "data/test_data.h5" in h5_paths_after, "EXTERNAL_RESOURCE relationship was lost after update" + + # Also verify by checking rels directly + rels = reader.get_obj_rels(identifier) + external_rels = [r for r in rels if r.type_value == EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type()] + assert len(external_rels) > 0, "EXTERNAL_RESOURCE relationship not found in rels" + assert any("test_data.h5" in r.target for r in external_rels) + + reader.close() + + def test_external_resource_preserved_when_referenced_by_other(self, temp_epc_file, sample_objects): + """Test that EXTERNAL_RESOURCE relationships are preserved when another object references this one.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Add BoundaryFeature with EXTERNAL_RESOURCE + bf = sample_objects["bf"] + bf_id = reader.add_object(bf) + + # Add EXTERNAL_RESOURCE relationship to BoundaryFeature + from energyml.opc.opc import Relationship + + h5_rel = Relationship( + target="data/boundary_data.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{bf_id}_h5", + ) + reader.add_rels_for_object(bf_id, [h5_rel], write_immediately=True) + + # Verify initial state + h5_paths_initial = reader.get_h5_file_paths(bf_id) + assert "data/boundary_data.h5" in h5_paths_initial + + # Add BoundaryFeatureInterpretation that references the BoundaryFeature + # This will create DESTINATION_OBJECT relationship in bf's rels file + bfi = sample_objects["bfi"] + reader.add_object(bfi) + + # Verify EXTERNAL_RESOURCE is still present after adding referencing object + h5_paths_after = reader.get_h5_file_paths(bf_id) + assert "data/boundary_data.h5" in h5_paths_after, "EXTERNAL_RESOURCE lost after adding referencing object" + + # Verify rels directly + rels = reader.get_obj_rels(bf_id) + external_rels = [r for r in rels if r.type_value == EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type()] + assert len(external_rels) > 0 + assert any("boundary_data.h5" in r.target for r in external_rels) + + reader.close() + + def test_external_resource_preserved_update_on_close_mode(self, temp_epc_file, sample_objects): + """Test EXTERNAL_RESOURCE preservation in UPDATE_ON_CLOSE mode.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_ON_CLOSE) + + # Add object + trset = sample_objects["trset"] + identifier = reader.add_object(trset) + + # Add EXTERNAL_RESOURCE relationship + from energyml.opc.opc import Relationship + + h5_rel = Relationship( + target="data/test_data.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{identifier}_h5", + ) + reader.add_rels_for_object(identifier, [h5_rel], write_immediately=True) + + # Update object + trset.citation.title = "Modified in UPDATE_ON_CLOSE mode" + reader.update_object(trset) + + # Close (triggers rebuild_all_rels in UPDATE_ON_CLOSE mode) + reader.close() + + # Reopen and verify + reader2 = EpcStreamReader(temp_epc_file) + h5_paths = 
reader2.get_h5_file_paths(identifier) + assert "data/test_data.h5" in h5_paths, "EXTERNAL_RESOURCE lost after close in UPDATE_ON_CLOSE mode" + reader2.close() + + def test_multiple_external_resources_preserved(self, temp_epc_file, sample_objects): + """Test that multiple EXTERNAL_RESOURCE relationships are all preserved.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Add object + trset = sample_objects["trset"] + identifier = reader.add_object(trset) + + # Add multiple EXTERNAL_RESOURCE relationships + from energyml.opc.opc import Relationship + + h5_rels = [ + Relationship( + target="data/geometry.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{identifier}_geometry", + ), + Relationship( + target="data/properties.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{identifier}_properties", + ), + Relationship( + target="data/metadata.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{identifier}_metadata", + ), + ] + reader.add_rels_for_object(identifier, h5_rels, write_immediately=True) + + # Verify all are present + h5_paths_before = reader.get_h5_file_paths(identifier) + assert "data/geometry.h5" in h5_paths_before + assert "data/properties.h5" in h5_paths_before + assert "data/metadata.h5" in h5_paths_before + + # Update object + trset.citation.title = "Updated with Multiple H5 Files" + reader.update_object(trset) + + # Verify all EXTERNAL_RESOURCE relationships are still present + h5_paths_after = reader.get_h5_file_paths(identifier) + assert "data/geometry.h5" in h5_paths_after + assert "data/properties.h5" in h5_paths_after + assert "data/metadata.h5" in h5_paths_after + + reader.close() + + def test_external_resource_preserved_cascade_updates(self, temp_epc_file, sample_objects): + """Test EXTERNAL_RESOURCE preserved through cascade of object updates.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Create chain: bf <- bfi <- trset + bf = sample_objects["bf"] + bfi = sample_objects["bfi"] + trset = sample_objects["trset"] + + # Add all objects + bf_id = reader.add_object(bf) + bfi_id = reader.add_object(bfi) + trset_id = reader.add_object(trset) + + # Add EXTERNAL_RESOURCE to bf (bottom of chain) + from energyml.opc.opc import Relationship + + h5_rel = Relationship( + target="data/bf_data.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{bf_id}_h5", + ) + reader.add_rels_for_object(bf_id, [h5_rel], write_immediately=True) + + # Verify initial state + h5_paths = reader.get_h5_file_paths(bf_id) + assert "data/bf_data.h5" in h5_paths + + # Update intermediate object (bfi) + bfi.citation.title = "Modified BFI" + reader.update_object(bfi) + + # Update top object (trset) + trset.citation.title = "Modified TriSet" + reader.update_object(trset) + + # Verify EXTERNAL_RESOURCE still present after cascade of updates + h5_paths_final = reader.get_h5_file_paths(bf_id) + assert "data/bf_data.h5" in h5_paths_final, "EXTERNAL_RESOURCE lost after cascade updates" + + reader.close() + + def test_external_resource_with_object_removal(self, temp_epc_file, sample_objects): + """Test that EXTERNAL_RESOURCE is properly handled when referenced object is removed.""" + reader = EpcStreamReader(temp_epc_file, rels_update_mode=RelsUpdateMode.UPDATE_AT_MODIFICATION) + + # Create bf and bfi (bfi references bf) + bf = sample_objects["bf"] + 
bfi = sample_objects["bfi"] + + bf_id = reader.add_object(bf) + bfi_id = reader.add_object(bfi) + + # Add EXTERNAL_RESOURCE to bfi + from energyml.opc.opc import Relationship + + h5_rel = Relationship( + target="data/bfi_data.h5", + type_value=EPCRelsRelationshipType.EXTERNAL_RESOURCE.get_type(), + id=f"_external_{bfi_id}_h5", + ) + reader.add_rels_for_object(bfi_id, [h5_rel], write_immediately=True) + + # Verify it exists + h5_paths = reader.get_h5_file_paths(bfi_id) + assert "data/bfi_data.h5" in h5_paths + + # Remove bf (which bfi references) + reader.remove_object(bf_id) + + # Update bfi (now its reference to bf is broken, but EXTERNAL_RESOURCE should remain) + bfi.citation.title = "Modified after BF removed" + reader.update_object(bfi) + + # Verify EXTERNAL_RESOURCE is still there + h5_paths_after = reader.get_h5_file_paths(bfi_id) + assert "data/bfi_data.h5" in h5_paths_after, "EXTERNAL_RESOURCE lost after referenced object removal" + + reader.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/energyml-utils/tests/test_parallel_rels_performance.py b/energyml-utils/tests/test_parallel_rels_performance.py new file mode 100644 index 0000000..2e1b6fa --- /dev/null +++ b/energyml-utils/tests/test_parallel_rels_performance.py @@ -0,0 +1,309 @@ +""" +Performance benchmarking tests for parallel rebuild_all_rels implementation. + +This module compares sequential vs parallel relationship rebuilding performance +on real EPC files. +""" + +import os +import time +import tempfile +import shutil +from pathlib import Path +import pytest + +from energyml.utils.epc_stream import EpcStreamReader + + +# Default test file path - can be overridden via environment variable +DEFAULT_TEST_FILE = r"C:\Users\Cryptaro\Downloads\80wells_surf.epc" +TEST_EPC_PATH = os.environ.get("TEST_EPC_PATH", DEFAULT_TEST_FILE) + + +def create_test_copy(source_path: str) -> str: + """Create a temporary copy of the EPC file for testing.""" + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "test.epc") + shutil.copy(source_path, temp_path) + return temp_path + + +@pytest.mark.slow +@pytest.mark.skipif(not os.path.exists(TEST_EPC_PATH), reason=f"Test EPC file not found: {TEST_EPC_PATH}") +class TestParallelRelsPerformance: + """Performance comparison tests for sequential vs parallel rebuild_all_rels. + + These tests are marked as 'slow' and skipped by default. 
+ Run with: pytest -m slow + """ + + def test_sequential_rebuild_performance(self): + """Benchmark sequential rebuild_all_rels implementation.""" + # Create test copy + test_file = create_test_copy(TEST_EPC_PATH) + + try: + # Open with sequential mode + reader = EpcStreamReader(test_file, enable_parallel_rels=False, keep_open=True) + + # Measure rebuild time + start_time = time.time() + stats = reader.rebuild_all_rels(clean_first=True) + end_time = time.time() + + execution_time = end_time - start_time + + # Verify stats + assert stats["objects_processed"] > 0, "Should process some objects" + assert stats["source_relationships"] > 0, "Should create SOURCE relationships" + assert stats["rels_files_created"] > 0, "Should create .rels files" + + # Print results + print(f"\n{'='*70}") + print(f"SEQUENTIAL MODE PERFORMANCE") + print(f"{'='*70}") + print(f"Objects processed: {stats['objects_processed']}") + print(f"SOURCE relationships: {stats['source_relationships']}") + print(f"DESTINATION relationships: {stats['destination_relationships']}") + print(f"Rels files created: {stats['rels_files_created']}") + print(f"Execution time: {execution_time:.3f}s") + print(f"Objects per second: {stats['objects_processed']/execution_time:.2f}") + print(f"{'='*70}\n") + + # Close reader before cleanup + reader.close() + + # Allow time for file handles to be released + import time as time_module + + time_module.sleep(0.5) + + # Store for comparison + return {"mode": "sequential", "execution_time": execution_time, "stats": stats} + + finally: + # Cleanup + try: + # Ensure directory is cleaned up + temp_dir = os.path.dirname(test_file) + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + def test_parallel_rebuild_performance(self): + """Benchmark parallel rebuild_all_rels implementation.""" + # Create test copy + test_file = create_test_copy(TEST_EPC_PATH) + + try: + # Open with parallel mode + reader = EpcStreamReader(test_file, enable_parallel_rels=True, keep_open=True) + + # Measure rebuild time + start_time = time.time() + stats = reader.rebuild_all_rels(clean_first=True) + end_time = time.time() + + execution_time = end_time - start_time + + # Verify stats + assert stats["objects_processed"] > 0, "Should process some objects" + assert stats["source_relationships"] > 0, "Should create SOURCE relationships" + assert stats["rels_files_created"] > 0, "Should create .rels files" + assert stats["parallel_mode"] is True, "Should indicate parallel mode" + + # Print results + print(f"\n{'='*70}") + print(f"PARALLEL MODE PERFORMANCE") + print(f"{'='*70}") + print(f"Objects processed: {stats['objects_processed']}") + print(f"SOURCE relationships: {stats['source_relationships']}") + print(f"DESTINATION relationships: {stats['destination_relationships']}") + print(f"Rels files created: {stats['rels_files_created']}") + print(f"Execution time: {execution_time:.3f}s") + print(f"Objects per second: {stats['objects_processed']/execution_time:.2f}") + print(f"{'='*70}\n") + + # Close reader before cleanup + reader.close() + + # Allow time for file handles to be released + import time as time_module + + time_module.sleep(0.5) + + return {"mode": "parallel", "execution_time": execution_time, "stats": stats} + + finally: + # Cleanup + try: + temp_dir = os.path.dirname(test_file) + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + def 
test_compare_sequential_vs_parallel(self): + """Direct comparison of sequential vs parallel performance.""" + # Run sequential + test_file_seq = create_test_copy(TEST_EPC_PATH) + + try: + reader_seq = EpcStreamReader(test_file_seq, enable_parallel_rels=False, keep_open=True) + start_seq = time.time() + stats_seq = reader_seq.rebuild_all_rels(clean_first=True) + time_seq = time.time() - start_seq + reader_seq.close() + finally: + if os.path.exists(test_file_seq): + os.unlink(test_file_seq) + if os.path.exists(os.path.dirname(test_file_seq)): + shutil.rmtree(os.path.dirname(test_file_seq)) + + # Run parallel + test_file_par = create_test_copy(TEST_EPC_PATH) + + try: + reader_par = EpcStreamReader(test_file_par, enable_parallel_rels=True, keep_open=True) + start_par = time.time() + stats_par = reader_par.rebuild_all_rels(clean_first=True) + time_par = time.time() - start_par + reader_par.close() + finally: + if os.path.exists(test_file_par): + os.unlink(test_file_par) + if os.path.exists(os.path.dirname(test_file_par)): + shutil.rmtree(os.path.dirname(test_file_par)) + + # Verify consistency + assert stats_seq["objects_processed"] == stats_par["objects_processed"], "Should process same number of objects" + assert ( + stats_seq["source_relationships"] == stats_par["source_relationships"] + ), "Should create same SOURCE relationships" + assert ( + stats_seq["destination_relationships"] == stats_par["destination_relationships"] + ), "Should create same DESTINATION relationships" + + # Calculate speedup + speedup = time_seq / time_par + speedup_percent = (time_seq - time_par) / time_seq * 100 + + # Print comparison + print(f"\n{'='*70}") + print(f"PERFORMANCE COMPARISON") + print(f"{'='*70}") + print(f"Test file: {os.path.basename(TEST_EPC_PATH)}") + print(f"Objects processed: {stats_seq['objects_processed']}") + print(f"-" * 70) + print(f"Sequential time: {time_seq:.3f}s") + print(f"Parallel time: {time_par:.3f}s") + print(f"-" * 70) + print(f"Speedup: {speedup:.2f}x") + print(f"Time saved: {speedup_percent:.1f}%") + print(f"Absolute savings: {time_seq - time_par:.3f}s") + print(f"{'='*70}\n") + + # Assert some improvement (parallel should be faster or at least not much slower) + # For small EPCs, overhead might make parallel slightly slower + # For large EPCs (80+ objects), parallel should be significantly faster + if stats_seq["objects_processed"] >= 50: + assert ( + time_par < time_seq * 1.2 + ), f"Parallel mode should not be >20% slower for {stats_seq['objects_processed']} objects" + + def test_correctness_parallel_vs_sequential(self): + """Verify that parallel and sequential produce identical results.""" + # Test with sequential + test_file_seq = create_test_copy(TEST_EPC_PATH) + + try: + reader_seq = EpcStreamReader(test_file_seq, enable_parallel_rels=False) + stats_seq = reader_seq.rebuild_all_rels(clean_first=True) + + # Read back relationships + rels_seq = {} + for identifier in reader_seq._metadata: + try: + obj_rels = reader_seq.get_obj_rels(identifier) + rels_seq[identifier] = sorted([(r.target, r.type_value) for r in obj_rels]) + except Exception: + rels_seq[identifier] = [] + + reader_seq.close() + finally: + if os.path.exists(test_file_seq): + os.unlink(test_file_seq) + if os.path.exists(os.path.dirname(test_file_seq)): + shutil.rmtree(os.path.dirname(test_file_seq)) + + # Test with parallel + test_file_par = create_test_copy(TEST_EPC_PATH) + + try: + reader_par = EpcStreamReader(test_file_par, enable_parallel_rels=True) + stats_par = reader_par.rebuild_all_rels(clean_first=True) 
+ + # Read back relationships + rels_par = {} + for identifier in reader_par._metadata: + try: + obj_rels = reader_par.get_obj_rels(identifier) + rels_par[identifier] = sorted([(r.target, r.type_value) for r in obj_rels]) + except Exception: + rels_par[identifier] = [] + + reader_par.close() + finally: + if os.path.exists(test_file_par): + os.unlink(test_file_par) + if os.path.exists(os.path.dirname(test_file_par)): + shutil.rmtree(os.path.dirname(test_file_par)) + + # Compare results + assert stats_seq["objects_processed"] == stats_par["objects_processed"] + assert stats_seq["source_relationships"] == stats_par["source_relationships"] + assert stats_seq["destination_relationships"] == stats_par["destination_relationships"] + + # Compare actual relationships (order-independent) + assert set(rels_seq.keys()) == set(rels_par.keys()), "Should have same objects" + + for identifier in rels_seq: + assert ( + rels_seq[identifier] == rels_par[identifier] + ), f"Relationships for {identifier} should match between sequential and parallel modes" + + print(f"\n✓ Correctness verified: Sequential and parallel modes produce identical results") + + +if __name__ == "__main__": + """Run performance tests directly.""" + import sys + + if len(sys.argv) > 1: + TEST_EPC_PATH = sys.argv[1] + + if not os.path.exists(TEST_EPC_PATH): + print(f"Error: Test file not found: {TEST_EPC_PATH}") + print(f"Usage: python {__file__} [path/to/test.epc]") + sys.exit(1) + + print(f"Running performance tests with: {TEST_EPC_PATH}\n") + + # Run tests + test = TestParallelRelsPerformance() + + try: + test.test_sequential_rebuild_performance() + test.test_parallel_rebuild_performance() + test.test_compare_sequential_vs_parallel() + test.test_correctness_parallel_vs_sequential() + + print("\n✓ All performance tests passed!") + + except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/energyml-utils/tests/test_uri.py b/energyml-utils/tests/test_uri.py index 8bb6044..5dda5a3 100644 --- a/energyml-utils/tests/test_uri.py +++ b/energyml-utils/tests/test_uri.py @@ -1,7 +1,7 @@ # Copyright (c) 2023-2024 Geosiris. # SPDX-License-Identifier: Apache-2.0 -from src.energyml.utils.uri import Uri, parse_uri +from energyml.utils.uri import Uri, parse_uri from energyml.utils.introspection import get_obj_uri from energyml.resqml.v2_0_1.resqmlv2 import TriangulatedSetRepresentation, ObjTriangulatedSetRepresentation diff --git a/energyml-utils/tests/test_xml.py b/energyml-utils/tests/test_xml.py index 4c454af..bfd3309 100644 --- a/energyml-utils/tests/test_xml.py +++ b/energyml-utils/tests/test_xml.py @@ -3,6 +3,7 @@ import logging +from energyml.utils.constants import parse_qualified_type from src.energyml.utils.xml import * CT_20 = "application/x-resqml+xml;version=2.0;type=obj_TriangulatedSetRepresentation"