From b2f6c1b02c1a15779467cfb97b62337b63e5b8ee Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Thu, 23 Apr 2026 14:43:35 +0000 Subject: [PATCH 1/6] Bump version for release --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 70d8792c..547d7ec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "multidimio" -version = "1.1.2" +version = "1.1.3" description = "Cloud-native, scalable, and user-friendly multi dimensional energy data!" authors = [{ name = "Altay Sansal", email = "altay.sansal@tgs.com" }] requires-python = ">=3.11,<3.14" @@ -183,7 +183,7 @@ init_typed = true warn_required_dynamic_aliases = true [tool.bumpversion] -current_version = "1.1.2" +current_version = "1.1.3" allow_dirty = true commit = false tag = false From 8260464a3efa7dbf5ed678e6b1c7065090800809 Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Mon, 18 May 2026 15:29:50 +0000 Subject: [PATCH 2/6] Modularize segy ingestion code --- src/mdio/converters/segy.py | 356 +-------------------- src/mdio/ingestion/__init__.py | 1 + src/mdio/ingestion/coordinates.py | 77 +++++ src/mdio/ingestion/grid_qc.py | 69 ++++ src/mdio/ingestion/metadata.py | 18 ++ src/mdio/ingestion/segy/__init__.py | 1 + src/mdio/ingestion/segy/coordinates.py | 157 +++++++++ src/mdio/ingestion/segy/file_headers.py | 48 +++ src/mdio/ingestion/segy/header_analysis.py | 275 ++++++++++++++++ src/mdio/ingestion/segy/validation.py | 35 ++ src/mdio/segy/geometry.py | 279 +--------------- uv.lock | 7 +- 12 files changed, 702 insertions(+), 621 deletions(-) create mode 100644 src/mdio/ingestion/__init__.py create mode 100644 src/mdio/ingestion/coordinates.py create mode 100644 src/mdio/ingestion/grid_qc.py create mode 100644 src/mdio/ingestion/metadata.py create mode 100644 src/mdio/ingestion/segy/__init__.py create mode 100644 src/mdio/ingestion/segy/coordinates.py create mode 100644 src/mdio/ingestion/segy/file_headers.py create 
mode 100644 src/mdio/ingestion/segy/header_analysis.py create mode 100644 src/mdio/ingestion/segy/validation.py diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index f0d34549..a8f0ba36 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -2,7 +2,6 @@ from __future__ import annotations -import base64 import logging from typing import TYPE_CHECKING @@ -10,8 +9,6 @@ import zarr from segy.config import SegyFileSettings from segy.config import SegyHeaderOverrides -from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem -from segy.standards.fields import binary as binary_header_fields from mdio.api.io import _normalize_path from mdio.api.io import to_mdio @@ -20,25 +17,28 @@ from mdio.builder.schemas.compressors import Blosc from mdio.builder.schemas.compressors import BloscCname from mdio.builder.schemas.dtype import ScalarType -from mdio.builder.schemas.v1.units import AngleUnitEnum -from mdio.builder.schemas.v1.units import AngleUnitModel -from mdio.builder.schemas.v1.units import LengthUnitEnum -from mdio.builder.schemas.v1.units import LengthUnitModel from mdio.builder.schemas.v1.variable import VariableMetadata from mdio.builder.xarray_builder import to_xarray_dataset from mdio.constants import ZarrFormat from mdio.converters.exceptions import GridTraceCountError -from mdio.converters.exceptions import GridTraceSparsityError from mdio.converters.type_converter import to_structured_type from mdio.core.config import MDIOSettings from mdio.core.grid import Grid from mdio.core.utils_write import MAX_COORDINATES_BYTES from mdio.core.utils_write import MAX_SIZE_LIVE_MASK from mdio.core.utils_write import get_constrained_chunksize +from mdio.ingestion.coordinates import populate_dim_coordinates # noqa: F401 re-export for compat +from mdio.ingestion.coordinates import populate_non_dim_coordinates # noqa: F401 re-export for compat +from mdio.ingestion.grid_qc import grid_density_qc +from 
mdio.ingestion.metadata import _add_grid_override_to_metadata +from mdio.ingestion.segy.coordinates import _get_coordinates +from mdio.ingestion.segy.coordinates import _get_spatial_coordinate_unit +from mdio.ingestion.segy.coordinates import _populate_coordinates +from mdio.ingestion.segy.coordinates import _update_template_units +from mdio.ingestion.segy.file_headers import _add_segy_file_headers +from mdio.ingestion.segy.validation import _validate_spec_in_template from mdio.segy import blocked_io from mdio.segy.file import get_segy_file_info -from mdio.segy.scalar import SCALE_COORDINATE_KEYS -from mdio.segy.scalar import _apply_coordinate_scalar from mdio.segy.utilities import get_grid_plan if TYPE_CHECKING: @@ -59,81 +59,6 @@ logger = logging.getLogger(__name__) -MEASUREMENT_SYSTEM_KEY = binary_header_fields.Rev0.MEASUREMENT_SYSTEM_CODE.model.name -ANGLE_UNIT_KEYS = ["angle", "azimuth"] -SPATIAL_UNIT_KEYS = [ - "cdp_x", - "cdp_y", - "source_coord_x", - "source_coord_y", - "group_coord_x", - "group_coord_y", - "offset", -] - - -def grid_density_qc(grid: Grid, num_traces: int) -> None: - """Quality control for sensible grid density during SEG-Y to MDIO conversion. - - This function checks the density of the proposed grid by comparing the total possible traces - (`grid_traces`) to the actual number of traces in the SEG-Y file (`num_traces`). A warning is - logged if the sparsity ratio (`grid_traces / num_traces`) exceeds a configurable threshold, - indicating potential inefficiency or misconfiguration. - - The warning threshold is set via the environment variable `MDIO__GRID__SPARSITY_RATIO_WARN` - (default 2), and the error threshold via `MDIO__GRID__SPARSITY_RATIO_LIMIT` (default 10). To - suppress the exception (but still log warnings), set `MDIO_IGNORE_CHECKS=1`. - - Args: - grid: The Grid instance to check. - num_traces: Expected number of traces from the SEG-Y file. 
- - Raises: - GridTraceSparsityError: If the sparsity ratio exceeds `MDIO__GRID__SPARSITY_RATIO_LIMIT` - and `MDIO_IGNORE_CHECKS` is not set to a truthy value (e.g., "1", "true"). - """ - settings = MDIOSettings() - # Calculate total possible traces in the grid (excluding sample dimension) - grid_traces = np.prod(grid.shape[:-1], dtype=np.uint64) - - # Handle division by zero if num_traces is 0 - sparsity_ratio = float("inf") if num_traces == 0 else grid_traces / num_traces - - # Fetch and validate environment variables - warning_ratio = settings.grid_sparsity_ratio_warn - error_ratio = settings.grid_sparsity_ratio_limit - ignore_checks = settings.ignore_checks - - # Check sparsity - should_warn = sparsity_ratio > warning_ratio - should_error = sparsity_ratio > error_ratio and not ignore_checks - - # Early return if everything is OK - # Prepare message for warning or error - if not should_warn and not should_error: - return - - # Build warning / error message - dims = dict(zip(grid.dim_names, grid.shape, strict=True)) - msg = ( - f"Ingestion grid is sparse. Sparsity ratio: {sparsity_ratio:.2f}. " - f"Ingestion grid: {dims}. " - f"SEG-Y trace count: {num_traces}, grid trace count: {grid_traces}." 
- ) - for dim_name in grid.dim_names: - dim_min = grid.get_min(dim_name) - dim_max = grid.get_max(dim_name) - msg += f"\n{dim_name} min: {dim_min} max: {dim_max}" - - # Log warning if sparsity exceeds warning threshold - if should_warn: - logger.warning(msg) - - # Raise error if sparsity exceeds error threshold and checks are not ignored - if should_error: - raise GridTraceSparsityError(grid.shape, num_traces, msg) - - def _patch_add_coordinates_for_non_binned( template: AbstractDatasetTemplate, non_binned_dims: set[str], @@ -342,242 +267,6 @@ def _build_and_check_grid( return grid -def _get_coordinates( - grid: Grid, - segy_headers: SegyHeaderArray, - mdio_template: AbstractDatasetTemplate, -) -> tuple[list[Dimension], dict[str, SegyHeaderArray]]: - """Get the data dim and non-dim coordinates from the SEG-Y headers and MDIO template. - - Select a subset of the segy_dimensions that corresponds to the MDIO dimensions - The dimensions are ordered as in the MDIO template. - The last dimension is always the vertical domain dimension - - Args: - grid: Inferred MDIO grid for SEG-Y file. - segy_headers: Headers read in from SEG-Y file. - mdio_template: The MDIO template to use for the conversion. - - Raises: - ValueError: If a dimension or coordinate name from the MDIO template is not found in - the SEG-Y headers. - - Returns: - A tuple containing: - - A list of dimension coordinates (1-D arrays). - - A dict of non-dimension coordinates (str: N-D arrays). - """ - dimensions_coords = [] - for dim_name in mdio_template.dimension_names: - if dim_name not in grid.dim_names: - err = f"Dimension '{dim_name}' was not found in SEG-Y dimensions." - raise ValueError(err) - dimensions_coords.append(grid.select_dim(dim_name)) - - non_dim_coords: dict[str, SegyHeaderArray] = {} - for coord_name in mdio_template.coordinate_names: - if coord_name not in segy_headers.dtype.names: - err = f"Coordinate '{coord_name}' not found in SEG-Y dimensions." 
- raise ValueError(err) - # Copy the data to allow segy_headers to be garbage collected - non_dim_coords[coord_name] = np.array(segy_headers[coord_name]) - - return dimensions_coords, non_dim_coords - - -def populate_dim_coordinates( - dataset: xr_Dataset, grid: Grid, drop_vars_delayed: list[str] -) -> tuple[xr_Dataset, list[str]]: - """Populate the xarray dataset with dimension coordinate variables.""" - for dim in grid.dims: - dataset[dim.name].values[:] = dim.coords - drop_vars_delayed.append(dim.name) - return dataset, drop_vars_delayed - - -def populate_non_dim_coordinates( - dataset: xr_Dataset, - grid: Grid, - coordinates: dict[str, SegyHeaderArray], - drop_vars_delayed: list[str], - spatial_coordinate_scalar: int, -) -> tuple[xr_Dataset, list[str]]: - """Populate the xarray dataset with coordinate variables. - - Memory optimization: Processes coordinates one at a time and explicitly - releases intermediate arrays to reduce peak memory usage. - """ - non_data_domain_dims = grid.dim_names[:-1] # minus the data domain dimension - - # Process coordinates one at a time to minimize peak memory - coord_names = list(coordinates.keys()) - for coord_name in coord_names: - coord_values = coordinates.pop(coord_name) # Remove from dict to free memory - da_coord = dataset[coord_name] - - # Get coordinate shape from dataset (uses dask shape, no memory allocation) - coord_shape = da_coord.shape - - # Create output array with fill value - fill_value = da_coord.encoding.get("_FillValue") or da_coord.encoding.get("fill_value") - if fill_value is None: - fill_value = np.nan - tmp_coord_values = np.full(coord_shape, fill_value, dtype=da_coord.dtype) - - # Compute slices for this coordinate's dimensions - coord_axes = tuple(non_data_domain_dims.index(coord_dim) for coord_dim in da_coord.dims) - coord_slices = tuple(slice(None) if idx in coord_axes else 0 for idx in range(len(non_data_domain_dims))) - - # Read only the required slice from grid map - coord_trace_indices = 
np.asarray(grid.map[coord_slices]) - - # Find valid (non-null) indices - not_null = coord_trace_indices != grid.map.fill_value - - # Populate values efficiently - if not_null.any(): - valid_indices = coord_trace_indices[not_null] - tmp_coord_values[not_null] = coord_values[valid_indices] - - # Apply scalar if needed - if coord_name in SCALE_COORDINATE_KEYS: - tmp_coord_values = _apply_coordinate_scalar(tmp_coord_values, spatial_coordinate_scalar) - - # Assign to dataset - dataset[coord_name][:] = tmp_coord_values - drop_vars_delayed.append(coord_name) - - # Explicitly release intermediate arrays - del tmp_coord_values, coord_trace_indices, not_null, coord_values - - # TODO(Altay): Add verification of reduced coordinates being the same as the first - # https://github.com/TGSAI/mdio-python/issues/645 - - return dataset, drop_vars_delayed - - -def _get_spatial_coordinate_unit(segy_file_info: SegyFileInfo) -> LengthUnitModel | None: - """Get the coordinate unit from the SEG-Y headers.""" - measurement_system_code = int(segy_file_info.binary_header_dict[MEASUREMENT_SYSTEM_KEY]) - - if measurement_system_code not in (1, 2): - logger.warning( - "Unexpected value in coordinate unit (%s) header: %s. Can't extract coordinate unit and will " - "ingest without coordinate units.", - MEASUREMENT_SYSTEM_KEY, - measurement_system_code, - ) - return None - - if measurement_system_code == SegyMeasurementSystem.METERS: - unit = LengthUnitEnum.METER - if measurement_system_code == SegyMeasurementSystem.FEET: - unit = LengthUnitEnum.FOOT - - return LengthUnitModel(length=unit) - - -def _update_template_units(template: AbstractDatasetTemplate, unit: LengthUnitModel | None) -> AbstractDatasetTemplate: - """Update the template with dynamic and some pre-defined units.""" - # Add units for pre-defined: angle and azimuth etc. 
- new_units = {key: AngleUnitModel(angle=AngleUnitEnum.DEGREES) for key in ANGLE_UNIT_KEYS} - - # If a spatial unit is not provided, we return as is - if unit is None: - template.add_units(new_units) - return template - - # Dynamically add units based on the spatial coordinate unit - for key in SPATIAL_UNIT_KEYS: - current_value = template.get_unit_by_key(key) - if current_value is not None: - logger.warning("Unit for %s already in template. Will keep the original unit: %s", key, current_value) - continue - - new_units[key] = unit - - template.add_units(new_units) - return template - - -def _populate_coordinates( - dataset: xr_Dataset, - grid: Grid, - coords: dict[str, SegyHeaderArray], - spatial_coordinate_scalar: int, -) -> tuple[xr_Dataset, list[str]]: - """Populate dim and non-dim coordinates in the xarray dataset and write to Zarr. - - This will write the xr Dataset with coords and dimensions, but empty traces and headers. - - Args: - dataset: The xarray dataset to populate. - grid: The grid object containing the grid map. - coords: The non-dim coordinates to populate. - spatial_coordinate_scalar: The X/Y coordinate scalar from the SEG-Y file. 
- - Returns: - Xarray dataset with filled coordinates and updated variables to drop after writing - """ - drop_vars_delayed = [] - # Populate the dimension coordinate variables (1-D arrays) - dataset, drop_vars_delayed = populate_dim_coordinates(dataset, grid, drop_vars_delayed=drop_vars_delayed) - - # Populate the non-dimension coordinate variables (N-dim arrays) - dataset, drop_vars_delayed = populate_non_dim_coordinates( - dataset, - grid, - coordinates=coords, - drop_vars_delayed=drop_vars_delayed, - spatial_coordinate_scalar=spatial_coordinate_scalar, - ) - - return dataset, drop_vars_delayed - - -def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo) -> xr_Dataset: - settings = MDIOSettings() - - if not settings.save_segy_file_header: - return xr_dataset - - expected_rows = 40 - expected_cols = 80 - - text_header_rows = segy_file_info.text_header.splitlines() - text_header_cols_bad = [len(row) != expected_cols for row in text_header_rows] - - if len(text_header_rows) != expected_rows: - err = f"Invalid text header count: expected {expected_rows}, got {len(segy_file_info.text_header)}" - raise ValueError(err) - - if any(text_header_cols_bad): - err = f"Invalid text header columns: expected {expected_cols} per line." 
- raise ValueError(err) - - xr_dataset["segy_file_header"] = ((), "") - xr_dataset["segy_file_header"].attrs.update( - { - "textHeader": segy_file_info.text_header, - "binaryHeader": segy_file_info.binary_header_dict, - } - ) - if settings.raw_headers: - raw_binary_base64 = base64.b64encode(segy_file_info.raw_binary_headers).decode("ascii") - xr_dataset["segy_file_header"].attrs.update({"rawBinaryHeader": raw_binary_base64}) - - return xr_dataset - - -def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, Any] | None) -> None: - """Add grid override to Dataset metadata if needed.""" - if dataset.metadata.attributes is None: - dataset.metadata.attributes = {} - - if grid_overrides is not None: - dataset.metadata.attributes["gridOverrides"] = grid_overrides - - def _add_raw_headers_to_template(mdio_template: AbstractDatasetTemplate) -> AbstractDatasetTemplate: """Add raw headers capability to the MDIO template by monkey-patching its _add_variables method. @@ -659,31 +348,6 @@ def determine_target_size(var_type: str) -> int: ds.variables[index].metadata.chunk_grid = chunk_grid -def _validate_spec_in_template(segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate) -> None: - """Validate that the SegySpec has all required fields in the MDIO template.""" - header_fields = {field.name for field in segy_spec.trace.header.fields} - - required_fields = set(mdio_template.spatial_dimension_names) | set(mdio_template.coordinate_names) - required_fields = required_fields - set(mdio_template.calculated_dimension_names) # remove to be calculated - - # For OBN template: 'component' is optional (will be synthesized if missing) - # Import here to avoid circular imports at module load time - from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - - if isinstance(mdio_template, Seismic3DObnReceiverGathersTemplate): - required_fields.discard("component") - - required_fields = required_fields | 
{"coordinate_scalar"} # ensure coordinate scalar is always present - missing_fields = required_fields - header_fields - - if missing_fields: - err = ( - f"Required fields {sorted(missing_fields)} for template {mdio_template.name} " - f"not found in the provided segy_spec" - ) - raise ValueError(err) - - def segy_to_mdio( # noqa PLR0913 segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate, diff --git a/src/mdio/ingestion/__init__.py b/src/mdio/ingestion/__init__.py new file mode 100644 index 00000000..00baad73 --- /dev/null +++ b/src/mdio/ingestion/__init__.py @@ -0,0 +1 @@ +"""MDIO ingestion helpers.""" diff --git a/src/mdio/ingestion/coordinates.py b/src/mdio/ingestion/coordinates.py new file mode 100644 index 00000000..9bc0a04b --- /dev/null +++ b/src/mdio/ingestion/coordinates.py @@ -0,0 +1,77 @@ +"""Generic coordinate population for ingestion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from mdio.segy.scalar import SCALE_COORDINATE_KEYS +from mdio.segy.scalar import _apply_coordinate_scalar + +if TYPE_CHECKING: + from segy.arrays import HeaderArray as SegyHeaderArray + from xarray import Dataset as xr_Dataset + + from mdio.core.grid import Grid + + +def populate_dim_coordinates( + dataset: xr_Dataset, grid: Grid, drop_vars_delayed: list[str] +) -> tuple[xr_Dataset, list[str]]: + """Populate the xarray dataset with dimension coordinate variables.""" + for dim in grid.dims: + dataset[dim.name].values[:] = dim.coords + drop_vars_delayed.append(dim.name) + return dataset, drop_vars_delayed + + +def populate_non_dim_coordinates( + dataset: xr_Dataset, + grid: Grid, + coordinates: dict[str, SegyHeaderArray], + drop_vars_delayed: list[str], + spatial_coordinate_scalar: int, +) -> tuple[xr_Dataset, list[str]]: + """Populate the xarray dataset with coordinate variables. + + Memory optimization: Processes coordinates one at a time and explicitly + releases intermediate arrays to reduce peak memory usage. 
+ """ + non_data_domain_dims = grid.dim_names[:-1] + + coord_names = list(coordinates.keys()) + for coord_name in coord_names: + coord_values = coordinates.pop(coord_name) + da_coord = dataset[coord_name] + + coord_shape = da_coord.shape + + fill_value = da_coord.encoding.get("_FillValue") or da_coord.encoding.get("fill_value") + if fill_value is None: + fill_value = np.nan + tmp_coord_values = np.full(coord_shape, fill_value, dtype=da_coord.dtype) + + coord_axes = tuple(non_data_domain_dims.index(coord_dim) for coord_dim in da_coord.dims) + coord_slices = tuple(slice(None) if idx in coord_axes else 0 for idx in range(len(non_data_domain_dims))) + + coord_trace_indices = np.asarray(grid.map[coord_slices]) + + not_null = coord_trace_indices != grid.map.fill_value + + if not_null.any(): + valid_indices = coord_trace_indices[not_null] + tmp_coord_values[not_null] = coord_values[valid_indices] + + if coord_name in SCALE_COORDINATE_KEYS: + tmp_coord_values = _apply_coordinate_scalar(tmp_coord_values, spatial_coordinate_scalar) + + dataset[coord_name][:] = tmp_coord_values + drop_vars_delayed.append(coord_name) + + del tmp_coord_values, coord_trace_indices, not_null, coord_values + + # TODO(Altay): Add verification of reduced coordinates being the same as the first + # https://github.com/TGSAI/mdio-python/issues/645 + + return dataset, drop_vars_delayed diff --git a/src/mdio/ingestion/grid_qc.py b/src/mdio/ingestion/grid_qc.py new file mode 100644 index 00000000..65ea05e3 --- /dev/null +++ b/src/mdio/ingestion/grid_qc.py @@ -0,0 +1,69 @@ +"""Grid sparsity quality control for ingestion.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np + +from mdio.converters.exceptions import GridTraceSparsityError +from mdio.core.config import MDIOSettings + +if TYPE_CHECKING: + from mdio.core.grid import Grid + +logger = logging.getLogger(__name__) + + +def grid_density_qc(grid: Grid, num_traces: int) -> None: + 
"""Quality control for sensible grid density during SEG-Y to MDIO conversion. + + This function checks the density of the proposed grid by comparing the total possible traces + (`grid_traces`) to the actual number of traces in the SEG-Y file (`num_traces`). A warning is + logged if the sparsity ratio (`grid_traces / num_traces`) exceeds a configurable threshold, + indicating potential inefficiency or misconfiguration. + + The warning threshold is set via the environment variable `MDIO__GRID__SPARSITY_RATIO_WARN` + (default 2), and the error threshold via `MDIO__GRID__SPARSITY_RATIO_LIMIT` (default 10). To + suppress the exception (but still log warnings), set `MDIO_IGNORE_CHECKS=1`. + + Args: + grid: The Grid instance to check. + num_traces: Expected number of traces from the SEG-Y file. + + Raises: + GridTraceSparsityError: If the sparsity ratio exceeds `MDIO__GRID__SPARSITY_RATIO_LIMIT` + and `MDIO_IGNORE_CHECKS` is not set to a truthy value (e.g., "1", "true"). + """ + settings = MDIOSettings() + grid_traces = np.prod(grid.shape[:-1], dtype=np.uint64) + + sparsity_ratio = float("inf") if num_traces == 0 else grid_traces / num_traces + + warning_ratio = settings.grid_sparsity_ratio_warn + error_ratio = settings.grid_sparsity_ratio_limit + ignore_checks = settings.ignore_checks + + should_warn = sparsity_ratio > warning_ratio + should_error = sparsity_ratio > error_ratio and not ignore_checks + + if not should_warn and not should_error: + return + + dims = dict(zip(grid.dim_names, grid.shape, strict=True)) + msg = ( + f"Ingestion grid is sparse. Sparsity ratio: {sparsity_ratio:.2f}. " + f"Ingestion grid: {dims}. " + f"SEG-Y trace count: {num_traces}, grid trace count: {grid_traces}." 
+ ) + for dim_name in grid.dim_names: + dim_min = grid.get_min(dim_name) + dim_max = grid.get_max(dim_name) + msg += f"\n{dim_name} min: {dim_min} max: {dim_max}" + + if should_warn: + logger.warning(msg) + + if should_error: + raise GridTraceSparsityError(grid.shape, num_traces, msg) diff --git a/src/mdio/ingestion/metadata.py b/src/mdio/ingestion/metadata.py new file mode 100644 index 00000000..674e91eb --- /dev/null +++ b/src/mdio/ingestion/metadata.py @@ -0,0 +1,18 @@ +"""Generic dataset metadata helpers for ingestion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +if TYPE_CHECKING: + from mdio.builder.schemas import Dataset + + +def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, Any] | None) -> None: + """Add grid override to Dataset metadata if needed.""" + if dataset.metadata.attributes is None: + dataset.metadata.attributes = {} + + if grid_overrides is not None: + dataset.metadata.attributes["gridOverrides"] = grid_overrides diff --git a/src/mdio/ingestion/segy/__init__.py b/src/mdio/ingestion/segy/__init__.py new file mode 100644 index 00000000..f50f0f37 --- /dev/null +++ b/src/mdio/ingestion/segy/__init__.py @@ -0,0 +1 @@ +"""SEG-Y specific ingestion helpers.""" diff --git a/src/mdio/ingestion/segy/coordinates.py b/src/mdio/ingestion/segy/coordinates.py new file mode 100644 index 00000000..5deceef2 --- /dev/null +++ b/src/mdio/ingestion/segy/coordinates.py @@ -0,0 +1,157 @@ +"""Coordinate extraction and unit resolution for SEG-Y ingestion.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np +from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem +from segy.standards.fields import binary as binary_header_fields + +from mdio.builder.schemas.v1.units import AngleUnitEnum +from mdio.builder.schemas.v1.units import AngleUnitModel +from mdio.builder.schemas.v1.units import LengthUnitEnum +from 
mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.ingestion.coordinates import populate_dim_coordinates +from mdio.ingestion.coordinates import populate_non_dim_coordinates + +if TYPE_CHECKING: + from segy.arrays import HeaderArray as SegyHeaderArray + from xarray import Dataset as xr_Dataset + + from mdio.builder.templates.base import AbstractDatasetTemplate + from mdio.core.dimension import Dimension + from mdio.core.grid import Grid + from mdio.segy.file import SegyFileInfo + +logger = logging.getLogger(__name__) + + +MEASUREMENT_SYSTEM_KEY = binary_header_fields.Rev0.MEASUREMENT_SYSTEM_CODE.model.name +ANGLE_UNIT_KEYS = ["angle", "azimuth"] +SPATIAL_UNIT_KEYS = [ + "cdp_x", + "cdp_y", + "source_coord_x", + "source_coord_y", + "group_coord_x", + "group_coord_y", + "offset", +] + + +def _get_coordinates( + grid: Grid, + segy_headers: SegyHeaderArray, + mdio_template: AbstractDatasetTemplate, +) -> tuple[list[Dimension], dict[str, SegyHeaderArray]]: + """Get the data dim and non-dim coordinates from the SEG-Y headers and MDIO template. + + Select a subset of the segy_dimensions that corresponds to the MDIO dimensions + The dimensions are ordered as in the MDIO template. + The last dimension is always the vertical domain dimension + + Args: + grid: Inferred MDIO grid for SEG-Y file. + segy_headers: Headers read in from SEG-Y file. + mdio_template: The MDIO template to use for the conversion. + + Raises: + ValueError: If a dimension or coordinate name from the MDIO template is not found in + the SEG-Y headers. + + Returns: + A tuple containing: + - A list of dimension coordinates (1-D arrays). + - A dict of non-dimension coordinates (str: N-D arrays). + """ + dimensions_coords = [] + for dim_name in mdio_template.dimension_names: + if dim_name not in grid.dim_names: + err = f"Dimension '{dim_name}' was not found in SEG-Y dimensions." 
+ raise ValueError(err) + dimensions_coords.append(grid.select_dim(dim_name)) + + non_dim_coords: dict[str, SegyHeaderArray] = {} + for coord_name in mdio_template.coordinate_names: + if coord_name not in segy_headers.dtype.names: + err = f"Coordinate '{coord_name}' not found in SEG-Y dimensions." + raise ValueError(err) + non_dim_coords[coord_name] = np.array(segy_headers[coord_name]) + + return dimensions_coords, non_dim_coords + + +def _populate_coordinates( + dataset: xr_Dataset, + grid: Grid, + coords: dict[str, SegyHeaderArray], + spatial_coordinate_scalar: int, +) -> tuple[xr_Dataset, list[str]]: + """Populate dim and non-dim coordinates in the xarray dataset and write to Zarr. + + This will write the xr Dataset with coords and dimensions, but empty traces and headers. + + Args: + dataset: The xarray dataset to populate. + grid: The grid object containing the grid map. + coords: The non-dim coordinates to populate. + spatial_coordinate_scalar: The X/Y coordinate scalar from the SEG-Y file. + + Returns: + Xarray dataset with filled coordinates and updated variables to drop after writing + """ + drop_vars_delayed = [] + dataset, drop_vars_delayed = populate_dim_coordinates(dataset, grid, drop_vars_delayed=drop_vars_delayed) + + dataset, drop_vars_delayed = populate_non_dim_coordinates( + dataset, + grid, + coordinates=coords, + drop_vars_delayed=drop_vars_delayed, + spatial_coordinate_scalar=spatial_coordinate_scalar, + ) + + return dataset, drop_vars_delayed + + +def _get_spatial_coordinate_unit(segy_file_info: SegyFileInfo) -> LengthUnitModel | None: + """Get the coordinate unit from the SEG-Y headers.""" + measurement_system_code = int(segy_file_info.binary_header_dict[MEASUREMENT_SYSTEM_KEY]) + + if measurement_system_code not in (1, 2): + logger.warning( + "Unexpected value in coordinate unit (%s) header: %s. 
Can't extract coordinate unit and will " + "ingest without coordinate units.", + MEASUREMENT_SYSTEM_KEY, + measurement_system_code, + ) + return None + + if measurement_system_code == SegyMeasurementSystem.METERS: + unit = LengthUnitEnum.METER + if measurement_system_code == SegyMeasurementSystem.FEET: + unit = LengthUnitEnum.FOOT + + return LengthUnitModel(length=unit) + + +def _update_template_units(template: AbstractDatasetTemplate, unit: LengthUnitModel | None) -> AbstractDatasetTemplate: + """Update the template with dynamic and some pre-defined units.""" + new_units = {key: AngleUnitModel(angle=AngleUnitEnum.DEGREES) for key in ANGLE_UNIT_KEYS} + + if unit is None: + template.add_units(new_units) + return template + + for key in SPATIAL_UNIT_KEYS: + current_value = template.get_unit_by_key(key) + if current_value is not None: + logger.warning("Unit for %s already in template. Will keep the original unit: %s", key, current_value) + continue + + new_units[key] = unit + + template.add_units(new_units) + return template diff --git a/src/mdio/ingestion/segy/file_headers.py b/src/mdio/ingestion/segy/file_headers.py new file mode 100644 index 00000000..367fa793 --- /dev/null +++ b/src/mdio/ingestion/segy/file_headers.py @@ -0,0 +1,48 @@ +"""Attach SEG-Y text and binary file headers to the dataset.""" + +from __future__ import annotations + +import base64 +from typing import TYPE_CHECKING + +from mdio.core.config import MDIOSettings + +if TYPE_CHECKING: + from xarray import Dataset as xr_Dataset + + from mdio.segy.file import SegyFileInfo + + +def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo) -> xr_Dataset: + """Attach the SEG-Y text and binary file headers as attrs on a scalar variable.""" + settings = MDIOSettings() + + if not settings.save_segy_file_header: + return xr_dataset + + expected_rows = 40 + expected_cols = 80 + + text_header_rows = segy_file_info.text_header.splitlines() + text_header_cols_bad = [len(row) != expected_cols 
for row in text_header_rows] + + if len(text_header_rows) != expected_rows: + err = f"Invalid text header count: expected {expected_rows}, got {len(segy_file_info.text_header)}" + raise ValueError(err) + + if any(text_header_cols_bad): + err = f"Invalid text header columns: expected {expected_cols} per line." + raise ValueError(err) + + xr_dataset["segy_file_header"] = ((), "") + xr_dataset["segy_file_header"].attrs.update( + { + "textHeader": segy_file_info.text_header, + "binaryHeader": segy_file_info.binary_header_dict, + } + ) + if settings.raw_headers: + raw_binary_base64 = base64.b64encode(segy_file_info.raw_binary_headers).decode("ascii") + xr_dataset["segy_file_header"].attrs.update({"rawBinaryHeader": raw_binary_base64}) + + return xr_dataset diff --git a/src/mdio/ingestion/segy/header_analysis.py b/src/mdio/ingestion/segy/header_analysis.py new file mode 100644 index 00000000..7aecf338 --- /dev/null +++ b/src/mdio/ingestion/segy/header_analysis.py @@ -0,0 +1,275 @@ +"""SEG-Y header analysis primitives for acquisition-geometry detection.""" + +from __future__ import annotations + +import logging +import time +from enum import Enum +from enum import auto +from typing import TYPE_CHECKING + +import numpy as np +from numpy.lib import recfunctions as rfn + +if TYPE_CHECKING: + from numpy.typing import DTypeLike + from numpy.typing import NDArray + from segy.arrays import HeaderArray + + +logger = logging.getLogger(__name__) + + +class StreamerShotGeometryType(Enum): + r"""Shot geometry template types for streamer acquisition. + + Configuration A: + Cable 1 -> 1------------------20 + Cable 2 -> 1-----------------20 + . 1-----------------20 + . ⛴ ☆ 1-----------------20 + . 1-----------------20 + Cable 6 -> 1-----------------20 + Cable 7 -> 1-----------------20 + + + Configuration B: + Cable 1 -> 1------------------20 + Cable 2 -> 21-----------------40 + . 41-----------------60 + . ⛴ ☆ 61-----------------80 + . 
81----------------100 + Cable 6 -> 101---------------120 + Cable 7 -> 121---------------140 + + Configuration C: + Cable ? -> / 1------------------20 + Cable ? -> / 21-----------------40 + . / 41-----------------60 + . ⛴ ☆ - 61-----------------80 + . \ 81----------------100 + Cable ? -> \ 101---------------120 + Cable ? -> \ 121---------------140 + """ + + A = auto() + B = auto() + C = auto() + + +class ShotGunGeometryType(Enum): + r"""Shot geometry template types for multi-gun acquisition. + + For shot lines with multiple guns, we can have two configurations for numbering shot_point. The + desired index is to have the shot point index for a given gun to be dense and unique + (configuration A). Typically the shot_point is unique for the line and therefore is not dense + for each gun (configuration B). + + Configuration A: + Gun 1 -> 1------------------20 + Gun 2 -> 1------------------20 + + Configuration B: + Gun 1 -> 1------------------39 + Gun 2 -> 2------------------40 + + """ + + A = auto() + B = auto() + + +def analyze_streamer_headers( + index_headers: HeaderArray, +) -> tuple[NDArray, NDArray, NDArray, StreamerShotGeometryType]: + """Check input headers for SEG-Y input to help determine geometry. + + This function reads in trace_qc_count headers and finds the unique cable values. The function + then checks to ensure channel numbers for different cables do not overlap. 
+ + Args: + index_headers: numpy array with index headers + + Returns: + tuple of unique_cables, cable_chan_min, cable_chan_max, geom_type + """ + unique_cables = np.sort(np.unique(index_headers["cable"])) + + cable_chan_min = np.empty(unique_cables.shape) + cable_chan_max = np.empty(unique_cables.shape) + + for idx, cable in enumerate(unique_cables): + cable_mask = index_headers["cable"] == cable + current_cable = index_headers["channel"][cable_mask] + + cable_chan_min[idx] = np.min(current_cable) + cable_chan_max[idx] = np.max(current_cable) + + geom_type = StreamerShotGeometryType.B + + for idx1, cable1 in enumerate(unique_cables): + min_val1 = cable_chan_min[idx1] + max_val1 = cable_chan_max[idx1] + + cable1_range = (min_val1, max_val1) + for idx2, cable2 in enumerate(unique_cables): + if cable2 == cable1: + continue + + min_val2 = cable_chan_min[idx2] + max_val2 = cable_chan_max[idx2] + cable2_range = (min_val2, max_val2) + + if min_val2 < max_val1 and max_val2 > min_val1: + geom_type = StreamerShotGeometryType.A + + logger.info("Found overlapping channels, assuming streamer type A") + overlap_info = ( + "Cable %s index %s with channel range %s overlaps cable %s index %s with " + "channel range %s. Check for aux trace issues if the overlap is unexpected. " + "To fix, modify the SEG-Y file or use AutoIndex grid override (not channel) " + "for channel number correction." + ) + logger.info(overlap_info, cable1, idx1, cable1_range, cable2, idx2, cable2_range) + + return unique_cables, cable_chan_min, cable_chan_max, geom_type + + return unique_cables, cable_chan_min, cable_chan_max, geom_type + + +def analyze_lines_for_guns( + index_headers: HeaderArray, + line_field: str = "sail_line", +) -> tuple[NDArray, dict[str, list], ShotGunGeometryType]: + """Check input headers for SEG-Y input to help determine geometry of shots and guns. + + This is a generalized function that works with any line field name (sail_line, shot_line, etc.) 
+ to analyze multi-gun acquisition geometry and determine if shot points are interleaved. + + Args: + index_headers: Numpy array with index headers. + line_field: Name of the line field to use (e.g., 'sail_line', 'shot_line'). + + Returns: + tuple of (unique_lines, unique_guns_per_line, geom_type) + """ + unique_lines = np.sort(np.unique(index_headers[line_field])) + unique_guns = np.sort(np.unique(index_headers["gun"])) + logger.info("unique_%s values: %s", line_field, unique_lines) + logger.info("unique_guns: %s", unique_guns) + + unique_guns_per_line = {} + + geom_type = ShotGunGeometryType.B + for line_val in unique_lines: + line_mask = index_headers[line_field] == line_val + shot_current = index_headers["shot_point"][line_mask] + gun_current = index_headers["gun"][line_mask] + + unique_guns_in_line = np.sort(np.unique(gun_current)) + num_guns = unique_guns_in_line.shape[0] + unique_guns_per_line[str(line_val)] = list(unique_guns_in_line) + + if geom_type == ShotGunGeometryType.A: + continue + + for gun in unique_guns_in_line: + gun_mask = gun_current == gun + shots_for_gun = shot_current[gun_mask] + num_shots = np.unique(shots_for_gun).shape[0] + mod_shots = np.floor(shots_for_gun / num_guns) + if len(np.unique(mod_shots)) != num_shots: + msg = "%s %s has %s shots; div by %s guns gives %s unique mod shots." + logger.info(msg, line_field, line_val, num_shots, num_guns, len(np.unique(mod_shots))) + geom_type = ShotGunGeometryType.A + break + + return unique_lines, unique_guns_per_line, geom_type + + +def analyze_saillines_for_guns( + index_headers: HeaderArray, +) -> tuple[NDArray, dict[str, list], ShotGunGeometryType]: + """Analyze sail lines for gun geometry. 
See analyze_lines_for_guns for details.""" + return analyze_lines_for_guns(index_headers, line_field="sail_line") + + +def create_counter( + depth: int, + total_depth: int, + unique_headers: dict[str, NDArray], + header_names: list[str], +) -> dict | int: + """Helper function to create dictionary tree for counting trace key for auto index.""" + if depth == total_depth: + return 0 + + counter = {} + + header_key = header_names[depth] + for header in unique_headers[header_key]: + counter[header] = create_counter(depth + 1, total_depth, unique_headers, header_names) + + return counter + + +def create_trace_index( + depth: int, + counter: dict, + index_headers: HeaderArray, + header_names: list[str], + dtype: DTypeLike = np.int16, +) -> NDArray | None: + """Update dictionary counter tree for counting trace key for auto index.""" + if depth == 0: + return None + + trace_no_field = np.zeros(index_headers.shape, dtype=dtype) + index_headers = rfn.append_fields(index_headers, "trace", trace_no_field, usemask=False) + + headers = [index_headers[name] for name in header_names[:depth]] + for idx, idx_values in enumerate(zip(*headers, strict=True)): + if depth == 1: + counter[idx_values[0]] += 1 + index_headers["trace"][idx] = counter[idx_values[0]] + else: + sub_counter = counter + for idx_value in idx_values[:-1]: + sub_counter = sub_counter[idx_value] + sub_counter[idx_values[-1]] += 1 + index_headers["trace"][idx] = sub_counter[idx_values[-1]] + + return index_headers + + +def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = np.int16) -> NDArray: + """Append a per-key running ``trace`` counter field to the index headers. + + Collects the sorted unique values of every header field except ``trace`` and builds a + nested counter from them; traces with identical header values are numbered 1, 2, 3, ... + + Args: + index_headers: numpy array with index headers + dtype: numpy type for value of created trace header. 
+ + Returns: + Index headers with an appended ``trace`` field; None when there are no fields to index. + """ + t_start = time.perf_counter() + unique_headers = {} + total_depth = 0 + header_names = [] + for header_key in index_headers.dtype.names: + if header_key != "trace": + unique_vals = np.sort(np.unique(index_headers[header_key])) + unique_headers[header_key] = unique_vals + header_names.append(header_key) + total_depth += 1 + + counter = create_counter(0, total_depth, unique_headers, header_names) + + index_headers = create_trace_index(total_depth, counter, index_headers, header_names, dtype=dtype) + + t_stop = time.perf_counter() + logger.debug("Time spent generating trace index: %.4f s", t_stop - t_start) + return index_headers diff --git a/src/mdio/ingestion/segy/validation.py b/src/mdio/ingestion/segy/validation.py new file mode 100644 index 00000000..1c74c96f --- /dev/null +++ b/src/mdio/ingestion/segy/validation.py @@ -0,0 +1,35 @@ +"""SegySpec/template validation for SEG-Y ingestion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from segy.schema import SegySpec + + from mdio.builder.templates.base import AbstractDatasetTemplate + + +def _validate_spec_in_template(segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate) -> None: + """Validate that the SegySpec has all required fields in the MDIO template.""" + # Import here to avoid circular imports at module load time + from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 + + header_fields = {field.name for field in segy_spec.trace.header.fields} + + required_fields = set(mdio_template.spatial_dimension_names) | set(mdio_template.coordinate_names) + required_fields = required_fields - set(mdio_template.calculated_dimension_names) + + # 'component' is optional for OBN (synthesized if missing) + if isinstance(mdio_template, Seismic3DObnReceiverGathersTemplate): + required_fields.discard("component") + + 
required_fields = required_fields | {"coordinate_scalar"} + missing_fields = required_fields - header_fields + + if missing_fields: + err = ( + f"Required fields {sorted(missing_fields)} for template {mdio_template.name} " + f"not found in the provided segy_spec" + ) + raise ValueError(err) diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index 16ec5cd1..a078db99 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -3,16 +3,22 @@ from __future__ import annotations import logging -import time from abc import ABC from abc import abstractmethod -from enum import Enum -from enum import auto from typing import TYPE_CHECKING import numpy as np from numpy.lib import recfunctions as rfn +# Re-exported for backward compatibility with `from mdio.segy.geometry import ...` callers. +from mdio.ingestion.segy.header_analysis import ShotGunGeometryType # noqa: F401 +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType # noqa: F401 +from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns +from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers +from mdio.ingestion.segy.header_analysis import analyze_saillines_for_guns # noqa: F401 +from mdio.ingestion.segy.header_analysis import analyze_streamer_headers +from mdio.ingestion.segy.header_analysis import create_counter # noqa: F401 +from mdio.ingestion.segy.header_analysis import create_trace_index # noqa: F401 from mdio.segy.exceptions import GridOverrideKeysError from mdio.segy.exceptions import GridOverrideMissingParameterError from mdio.segy.exceptions import GridOverrideUnknownError @@ -20,7 +26,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from numpy.typing import DTypeLike from numpy.typing import NDArray from segy.arrays import HeaderArray @@ -30,272 +35,6 @@ logger = logging.getLogger(__name__) -class StreamerShotGeometryType(Enum): - r"""Shot geometry template types for streamer acquisition. 
- - Configuration A: - Cable 1 -> 1------------------20 - Cable 2 -> 1-----------------20 - . 1-----------------20 - . ⛴ ☆ 1-----------------20 - . 1-----------------20 - Cable 6 -> 1-----------------20 - Cable 7 -> 1-----------------20 - - - Configuration B: - Cable 1 -> 1------------------20 - Cable 2 -> 21-----------------40 - . 41-----------------60 - . ⛴ ☆ 61-----------------80 - . 81----------------100 - Cable 6 -> 101---------------120 - Cable 7 -> 121---------------140 - - Configuration C: - Cable ? -> / 1------------------20 - Cable ? -> / 21-----------------40 - . / 41-----------------60 - . ⛴ ☆ - 61-----------------80 - . \ 81----------------100 - Cable ? -> \ 101---------------120 - Cable ? -> \ 121---------------140 - """ - - A = auto() - B = auto() - C = auto() - - -class ShotGunGeometryType(Enum): - r"""Shot geometry template types for multi-gun acquisition. - - For shot lines with multiple guns, we can have two configurations for numbering shot_point. The - desired index is to have the shot point index for a given gun to be dense and unique - (configuration A). Typically the shot_point is unique for the line and therefore is not dense - for each gun (configuration B). - - Configuration A: - Gun 1 -> 1------------------20 - Gun 2 -> 1------------------20 - - Configuration B: - Gun 1 -> 1------------------39 - Gun 2 -> 2------------------40 - - """ - - A = auto() - B = auto() - - -def analyze_streamer_headers( - index_headers: HeaderArray, -) -> tuple[NDArray, NDArray, NDArray, StreamerShotGeometryType]: - """Check input headers for SEG-Y input to help determine geometry. - - This function reads in trace_qc_count headers and finds the unique cable values. The function - then checks to ensure channel numbers for different cables do not overlap. 
- - Args: - index_headers: numpy array with index headers - - Returns: - tuple of unique_cables, cable_chan_min, cable_chan_max, geom_type - """ - # Find unique cable ids - unique_cables = np.sort(np.unique(index_headers["cable"])) - - # Find channel min and max values for each cable - cable_chan_min = np.empty(unique_cables.shape) - cable_chan_max = np.empty(unique_cables.shape) - - for idx, cable in enumerate(unique_cables): - cable_mask = index_headers["cable"] == cable - current_cable = index_headers["channel"][cable_mask] - - cable_chan_min[idx] = np.min(current_cable) - cable_chan_max[idx] = np.max(current_cable) - - # Check channel numbers do not overlap for case B - geom_type = StreamerShotGeometryType.B - - for idx1, cable1 in enumerate(unique_cables): - min_val1 = cable_chan_min[idx1] - max_val1 = cable_chan_max[idx1] - - cable1_range = (min_val1, max_val1) - for idx2, cable2 in enumerate(unique_cables): - if cable2 == cable1: - continue - - min_val2 = cable_chan_min[idx2] - max_val2 = cable_chan_max[idx2] - cable2_range = (min_val2, max_val2) - - # Check for overlap and return early with Type A - if min_val2 < max_val1 and max_val2 > min_val1: - geom_type = StreamerShotGeometryType.A - - logger.info("Found overlapping channels, assuming streamer type A") - overlap_info = ( - "Cable %s index %s with channel range %s overlaps cable %s index %s with " - "channel range %s. Check for aux trace issues if the overlap is unexpected. " - "To fix, modify the SEG-Y file or use AutoIndex grid override (not channel) " - "for channel number correction." 
- ) - logger.info(overlap_info, cable1, idx1, cable1_range, cable2, idx2, cable2_range) - - return unique_cables, cable_chan_min, cable_chan_max, geom_type - - return unique_cables, cable_chan_min, cable_chan_max, geom_type - - -def analyze_lines_for_guns( - index_headers: HeaderArray, - line_field: str = "sail_line", -) -> tuple[NDArray, dict[str, list], ShotGunGeometryType]: - """Check input headers for SEG-Y input to help determine geometry of shots and guns. - - This is a generalized function that works with any line field name (sail_line, shot_line, etc.) - to analyze multi-gun acquisition geometry and determine if shot points are interleaved. - - Args: - index_headers: Numpy array with index headers. - line_field: Name of the line field to use (e.g., 'sail_line', 'shot_line'). - - Returns: - tuple of (unique_lines, unique_guns_per_line, geom_type) - """ - unique_lines = np.sort(np.unique(index_headers[line_field])) - unique_guns = np.sort(np.unique(index_headers["gun"])) - logger.info("unique_%s values: %s", line_field, unique_lines) - logger.info("unique_guns: %s", unique_guns) - - unique_guns_per_line = {} - - geom_type = ShotGunGeometryType.B - # Check shot numbers are still unique if div/num_guns - for line_val in unique_lines: - line_mask = index_headers[line_field] == line_val - shot_current = index_headers["shot_point"][line_mask] - gun_current = index_headers["gun"][line_mask] - - unique_guns_in_line = np.sort(np.unique(gun_current)) - num_guns = unique_guns_in_line.shape[0] - unique_guns_per_line[str(line_val)] = list(unique_guns_in_line) - - # Skip geometry detection if we already know it's Type A - if geom_type == ShotGunGeometryType.A: - continue - - for gun in unique_guns_in_line: - gun_mask = gun_current == gun - shots_for_gun = shot_current[gun_mask] - num_shots = np.unique(shots_for_gun).shape[0] - mod_shots = np.floor(shots_for_gun / num_guns) - if len(np.unique(mod_shots)) != num_shots: - msg = "%s %s has %s shots; div by %s guns gives %s 
unique mod shots." - logger.info(msg, line_field, line_val, num_shots, num_guns, len(np.unique(mod_shots))) - geom_type = ShotGunGeometryType.A - break # No need to check more guns for this line - - return unique_lines, unique_guns_per_line, geom_type - - -# Backward-compatible aliases for existing code -def analyze_saillines_for_guns( - index_headers: HeaderArray, -) -> tuple[NDArray, dict[str, list], ShotGunGeometryType]: - """Analyze sail lines for gun geometry. See analyze_lines_for_guns for details.""" - return analyze_lines_for_guns(index_headers, line_field="sail_line") - - -def create_counter( - depth: int, - total_depth: int, - unique_headers: dict[str, NDArray], - header_names: list[str], -) -> dict[str, dict]: - """Helper function to create dictionary tree for counting trace key for auto index.""" - if depth == total_depth: - return 0 - - counter = {} - - header_key = header_names[depth] - for header in unique_headers[header_key]: - counter[header] = create_counter(depth + 1, total_depth, unique_headers, header_names) - - return counter - - -def create_trace_index( - depth: int, - counter: dict, - index_headers: HeaderArray, - header_names: list[str], - dtype: DTypeLike = np.int16, -) -> NDArray | None: - """Update dictionary counter tree for counting trace key for auto index.""" - if depth == 0: - # If there's no hierarchical depth, no tracing needed. 
- return None - - # Add index header - trace_no_field = np.zeros(index_headers.shape, dtype=dtype) - index_headers = rfn.append_fields(index_headers, "trace", trace_no_field, usemask=False) - - # Extract the relevant columns upfront - headers = [index_headers[name] for name in header_names[:depth]] - for idx, idx_values in enumerate(zip(*headers, strict=True)): - if depth == 1: - counter[idx_values[0]] += 1 - index_headers["trace"][idx] = counter[idx_values[0]] - else: - sub_counter = counter - for idx_value in idx_values[:-1]: - sub_counter = sub_counter[idx_value] - sub_counter[idx_values[-1]] += 1 - index_headers["trace"][idx] = sub_counter[idx_values[-1]] - - return index_headers - - -def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = np.int16) -> NDArray: - """Check input headers for SEG-Y input to help determine geometry. - - This function reads in trace_qc_count headers and finds the unique cable values. Then, it - checks to make sure channel numbers for different cables do not overlap. - - Args: - index_headers: numpy array with index headers - dtype: numpy type for value of created trace header. 
- - Returns: - Dict container header name as key and numpy array of values as value - """ - # Find unique cable ids - t_start = time.perf_counter() - unique_headers = {} - total_depth = 0 - header_names = [] - for header_key in index_headers.dtype.names: - if header_key != "trace": - unique_vals = np.sort(np.unique(index_headers[header_key])) - unique_headers[header_key] = unique_vals - header_names.append(header_key) - total_depth += 1 - - counter = create_counter(0, total_depth, unique_headers, header_names) - - index_headers = create_trace_index(total_depth, counter, index_headers, header_names, dtype=dtype) - - t_stop = time.perf_counter() - logger.debug("Time spent generating trace index: %.4f s", t_start - t_stop) - return index_headers - - class GridOverrideCommand(ABC): """Abstract base class for grid override commands.""" diff --git a/uv.lock b/uv.lock index 64a928a8..1e0aed4e 100644 --- a/uv.lock +++ b/uv.lock @@ -1181,7 +1181,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/cb/48e964c452ca2b92175a9b2dca037a553036cb053ba69e284650ce755f13/greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e", size = 274908, upload-time = "2025-12-04T14:23:26.435Z" }, { url = "https://files.pythonhosted.org/packages/28/da/38d7bff4d0277b594ec557f479d65272a893f1f2a716cad91efeb8680953/greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62", size = 577113, upload-time = "2025-12-04T14:50:05.493Z" }, { url = "https://files.pythonhosted.org/packages/3c/f2/89c5eb0faddc3ff014f1c04467d67dee0d1d334ab81fadbf3744847f8a8a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32", size = 590338, upload-time = "2025-12-04T14:57:41.136Z" }, - { url = 
"https://files.pythonhosted.org/packages/80/d7/db0a5085035d05134f8c089643da2b44cc9b80647c39e93129c5ef170d8f/greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45", size = 601098, upload-time = "2025-12-04T15:07:11.898Z" }, { url = "https://files.pythonhosted.org/packages/dc/a6/e959a127b630a58e23529972dbc868c107f9d583b5a9f878fb858c46bc1a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948", size = 590206, upload-time = "2025-12-04T14:26:01.254Z" }, { url = "https://files.pythonhosted.org/packages/48/60/29035719feb91798693023608447283b266b12efc576ed013dd9442364bb/greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794", size = 1550668, upload-time = "2025-12-04T15:04:22.439Z" }, { url = "https://files.pythonhosted.org/packages/0a/5f/783a23754b691bfa86bd72c3033aa107490deac9b2ef190837b860996c9f/greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5", size = 1615483, upload-time = "2025-12-04T14:27:28.083Z" }, @@ -1189,7 +1188,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = 
"https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, - { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -1197,7 +1195,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = 
"https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, - { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1967,7 +1964,7 @@ wheels = [ 
[[package]] name = "multidimio" -version = "1.1.2" +version = "1.1.3" source = { editable = "." } dependencies = [ { name = "click" }, @@ -2048,7 +2045,7 @@ requires-dist = [ { name = "tqdm", specifier = ">=4.67.1" }, { name = "universal-pathlib", specifier = ">=0.3.3" }, { name = "xarray", specifier = ">=2025.10.1" }, - { name = "zarr", specifier = ">=3.1.3" }, + { name = "zarr", specifier = ">=3.1.3,<=3.1.6" }, { name = "zfpy", marker = "extra == 'lossy'", specifier = ">=1.0.1" }, ] provides-extras = ["cloud", "distributed", "lossy"] From b19188cae66bcc74bbf5a48b4d463e643494492f Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Mon, 18 May 2026 20:24:21 +0000 Subject: [PATCH 3/6] Remove unnecessary compat imports --- src/mdio/converters/segy.py | 2 -- src/mdio/segy/geometry.py | 8 ++------ tests/integration/conftest.py | 2 +- tests/integration/test_import_streamer_grid_overrides.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index a8f0ba36..59545f3b 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -27,8 +27,6 @@ from mdio.core.utils_write import MAX_COORDINATES_BYTES from mdio.core.utils_write import MAX_SIZE_LIVE_MASK from mdio.core.utils_write import get_constrained_chunksize -from mdio.ingestion.coordinates import populate_dim_coordinates # noqa: F401 re-export for compat -from mdio.ingestion.coordinates import populate_non_dim_coordinates # noqa: F401 re-export for compat from mdio.ingestion.grid_qc import grid_density_qc from mdio.ingestion.metadata import _add_grid_override_to_metadata from mdio.ingestion.segy.coordinates import _get_coordinates diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index a078db99..cbf02be0 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -10,15 +10,11 @@ import numpy as np from numpy.lib import recfunctions as rfn -# Re-exported for backward compatibility with `from 
mdio.segy.geometry import ...` callers. -from mdio.ingestion.segy.header_analysis import ShotGunGeometryType # noqa: F401 -from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType # noqa: F401 +from mdio.ingestion.segy.header_analysis import ShotGunGeometryType +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers -from mdio.ingestion.segy.header_analysis import analyze_saillines_for_guns # noqa: F401 from mdio.ingestion.segy.header_analysis import analyze_streamer_headers -from mdio.ingestion.segy.header_analysis import create_counter # noqa: F401 -from mdio.ingestion.segy.header_analysis import create_trace_index # noqa: F401 from mdio.segy.exceptions import GridOverrideKeysError from mdio.segy.exceptions import GridOverrideMissingParameterError from mdio.segy.exceptions import GridOverrideUnknownError diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 5d14fc0d..2f17624e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -11,7 +11,7 @@ from segy.standards import SegyStandard from segy.standards import get_segy_standard -from mdio.segy.geometry import StreamerShotGeometryType +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType if TYPE_CHECKING: from pathlib import Path diff --git a/tests/integration/test_import_streamer_grid_overrides.py b/tests/integration/test_import_streamer_grid_overrides.py index 7077ba9f..52be8f34 100644 --- a/tests/integration/test_import_streamer_grid_overrides.py +++ b/tests/integration/test_import_streamer_grid_overrides.py @@ -15,7 +15,7 @@ from mdio.builder.template_registry import TemplateRegistry from mdio.converters.exceptions import GridTraceSparsityError from mdio.converters.segy import segy_to_mdio -from mdio.segy.geometry import StreamerShotGeometryType +from 
mdio.ingestion.segy.header_analysis import StreamerShotGeometryType if TYPE_CHECKING: from pathlib import Path From dc8f98d3b5c21f6b24043bc43b2c4b0856903353 Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Mon, 18 May 2026 20:33:05 +0000 Subject: [PATCH 4/6] Implement unit tests for refactored helper functions --- tests/unit/ingestion/__init__.py | 1 + tests/unit/ingestion/test_coordinates.py | 291 ++++++++++++++++++ tests/unit/ingestion/test_grid_qc.py | 117 +++++++ tests/unit/ingestion/test_metadata.py | 42 +++ tests/unit/ingestion/test_segy_coordinates.py | 222 +++++++++++++ .../unit/ingestion/test_segy_file_headers.py | 94 ++++++ .../ingestion/test_segy_header_analysis.py | 242 +++++++++++++++ tests/unit/ingestion/test_segy_validation.py | 104 +++++++ tests/unit/ingestion/testing_helpers.py | 73 +++++ 9 files changed, 1186 insertions(+) create mode 100644 tests/unit/ingestion/__init__.py create mode 100644 tests/unit/ingestion/test_coordinates.py create mode 100644 tests/unit/ingestion/test_grid_qc.py create mode 100644 tests/unit/ingestion/test_metadata.py create mode 100644 tests/unit/ingestion/test_segy_coordinates.py create mode 100644 tests/unit/ingestion/test_segy_file_headers.py create mode 100644 tests/unit/ingestion/test_segy_header_analysis.py create mode 100644 tests/unit/ingestion/test_segy_validation.py create mode 100644 tests/unit/ingestion/testing_helpers.py diff --git a/tests/unit/ingestion/__init__.py b/tests/unit/ingestion/__init__.py new file mode 100644 index 00000000..10995e7a --- /dev/null +++ b/tests/unit/ingestion/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ingestion helper modules.""" diff --git a/tests/unit/ingestion/test_coordinates.py b/tests/unit/ingestion/test_coordinates.py new file mode 100644 index 00000000..68f7fbf2 --- /dev/null +++ b/tests/unit/ingestion/test_coordinates.py @@ -0,0 +1,291 @@ +"""Tests for generic coordinate population helpers in ingestion.""" + +from __future__ import annotations + +import numpy as np 
+import pytest +from xarray import DataArray as xr_DataArray +from xarray import Dataset as xr_Dataset + +from mdio.ingestion.coordinates import populate_dim_coordinates +from mdio.ingestion.coordinates import populate_non_dim_coordinates +from tests.unit.ingestion.testing_helpers import make_grid +from tests.unit.ingestion.testing_helpers import make_grid_with_map + + +class TestPopulateDimCoordinates: + """Tests for ``populate_dim_coordinates``.""" + + def test_assigns_coords_for_each_dim(self) -> None: + """Dim coords should be copied from Grid dims onto the dataset arrays.""" + inline_coords = np.array([10, 20, 30], dtype=np.int32) + crossline_coords = np.array([100, 200], dtype=np.int32) + depth_coords = np.array([0, 4, 8, 12], dtype=np.int32) + grid = make_grid( + [ + ("inline", inline_coords), + ("crossline", crossline_coords), + ("depth", depth_coords), + ] + ) + + dataset = xr_Dataset( + { + "inline": xr_DataArray(np.zeros(3, dtype=np.int32), dims=["inline"]), + "crossline": xr_DataArray(np.zeros(2, dtype=np.int32), dims=["crossline"]), + "depth": xr_DataArray(np.zeros(4, dtype=np.int32), dims=["depth"]), + } + ) + + dataset, drop_vars = populate_dim_coordinates(dataset, grid, drop_vars_delayed=[]) + + np.testing.assert_array_equal(dataset["inline"].values, inline_coords) + np.testing.assert_array_equal(dataset["crossline"].values, crossline_coords) + np.testing.assert_array_equal(dataset["depth"].values, depth_coords) + assert drop_vars == ["inline", "crossline", "depth"] + + def test_extends_existing_drop_vars(self) -> None: + """The drop list should be extended, not replaced.""" + grid = make_grid([("x", np.array([1, 2], dtype=np.int32))]) + dataset = xr_Dataset({"x": xr_DataArray(np.zeros(2, dtype=np.int32), dims=["x"])}) + + _, drop_vars = populate_dim_coordinates(dataset, grid, drop_vars_delayed=["already_there"]) + + assert drop_vars == ["already_there", "x"] + + +class TestPopulateNonDimCoordinates: + """Tests for 
``populate_non_dim_coordinates``.""" + + def _make_dataset_with_coord( + self, + coord_name: str, + shape: tuple[int, ...], + dims: tuple[str, ...], + encoding: dict | None, + dtype: np.dtype, + ) -> xr_Dataset: + data = xr_DataArray(np.zeros(shape, dtype=dtype), dims=list(dims)) + if encoding is not None: + data.encoding.update(encoding) + return xr_Dataset({coord_name: data}) + + def test_populates_2d_coordinate_with_scaling(self) -> None: + """Spatial coord ``cdp_x`` should be filled and scaled.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20, 30], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + # Inline-major live records → trace indices 0..5 populate the full (2, 3) grid. + live = [(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + coord_values = np.array([100, 200, 300, 400, 500, 600], dtype=np.float64) + coordinates = {"cdp_x": coord_values} + + dataset = self._make_dataset_with_coord( + coord_name="cdp_x", + shape=(2, 3), + dims=("inline", "crossline"), + encoding={"_FillValue": np.float64(-1.0)}, + dtype=np.float64, + ) + + dataset, drop_vars = populate_non_dim_coordinates( + dataset, + grid, + coordinates=coordinates, + drop_vars_delayed=[], + spatial_coordinate_scalar=10, + ) + + expected = (coord_values.reshape((2, 3)) * 10).astype(np.float64) + np.testing.assert_array_equal(dataset["cdp_x"].values, expected) + assert drop_vars == ["cdp_x"] + assert coordinates == {} + + def test_uses_fill_value_for_dead_traces(self) -> None: + """Cells without a live trace should keep the configured fill value.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + # Only 3 of 4 cells are live; (inline=1, crossline=20) is dead. 
+ live = [(1, 10), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + coord_values = np.array([100.0, 200.0, 300.0], dtype=np.float64) + dataset = self._make_dataset_with_coord( + coord_name="cdp_x", + shape=(2, 2), + dims=("inline", "crossline"), + encoding={"_FillValue": np.float64(-9999.0)}, + dtype=np.float64, + ) + + dataset, _ = populate_non_dim_coordinates( + dataset, + grid, + coordinates={"cdp_x": coord_values}, + drop_vars_delayed=[], + spatial_coordinate_scalar=1, + ) + + expected = np.array([[100.0, -9999.0], [200.0, 300.0]], dtype=np.float64) + np.testing.assert_array_equal(dataset["cdp_x"].values, expected) + + def test_non_spatial_coordinate_not_scaled(self) -> None: + """Non-spatial coords (e.g. offset) must not be touched by coord scalar.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + live = [(1, 10), (1, 20), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + coord_values = np.array([5, 6, 7, 8], dtype=np.float64) + dataset = self._make_dataset_with_coord( + coord_name="not_spatial", + shape=(2, 2), + dims=("inline", "crossline"), + encoding={"_FillValue": np.float64(0.0)}, + dtype=np.float64, + ) + + dataset, _ = populate_non_dim_coordinates( + dataset, + grid, + coordinates={"not_spatial": coord_values}, + drop_vars_delayed=[], + spatial_coordinate_scalar=100, # would change values if applied + ) + + np.testing.assert_array_equal(dataset["not_spatial"].values, coord_values.reshape((2, 2))) + + def test_reduced_coordinate_uses_slice(self) -> None: + """A coord declared on a subset of dims should be filled via a sliced map.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20, 30], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + 
live = [(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + # Trace indices along the inline=0 row are 0, 1, 2 so the coord values + # at those positions are taken from coord_values[0:3]. + coord_values = np.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], dtype=np.float64) + dataset = self._make_dataset_with_coord( + coord_name="offset", + shape=(3,), + dims=("crossline",), + encoding={"_FillValue": np.float64(-1.0)}, + dtype=np.float64, + ) + + dataset, _ = populate_non_dim_coordinates( + dataset, + grid, + coordinates={"offset": coord_values}, + drop_vars_delayed=[], + spatial_coordinate_scalar=1, + ) + + np.testing.assert_array_equal(dataset["offset"].values, coord_values[:3]) + + def test_default_fill_value_is_nan_when_encoding_missing(self) -> None: + """When no ``_FillValue`` / ``fill_value`` is set, dead traces become NaN.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + live = [(1, 10), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + coord_values = np.array([1.5, 2.5, 3.5], dtype=np.float64) + dataset = self._make_dataset_with_coord( + coord_name="cdp_x", + shape=(2, 2), + dims=("inline", "crossline"), + encoding=None, + dtype=np.float64, + ) + + dataset, _ = populate_non_dim_coordinates( + dataset, + grid, + coordinates={"cdp_x": coord_values}, + drop_vars_delayed=[], + spatial_coordinate_scalar=1, + ) + + actual = dataset["cdp_x"].values + assert np.isnan(actual[0, 1]) + assert actual[0, 0] == pytest.approx(1.5) + assert actual[1, 0] == pytest.approx(2.5) + assert actual[1, 1] == pytest.approx(3.5) + + def test_fill_value_key_in_encoding_is_honored(self) -> None: + """The lowercase ``fill_value`` encoding key must be honored when ``_FillValue`` 
is absent.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + # Dead cell at (inline=1, crossline=20). + live = [(1, 10), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + coord_values = np.array([1.0, 2.0, 3.0], dtype=np.float64) + dataset = self._make_dataset_with_coord( + coord_name="cdp_x", + shape=(2, 2), + dims=("inline", "crossline"), + encoding={"fill_value": np.float64(-42.0)}, + dtype=np.float64, + ) + + dataset, _ = populate_non_dim_coordinates( + dataset, + grid, + coordinates={"cdp_x": coord_values}, + drop_vars_delayed=[], + spatial_coordinate_scalar=1, + ) + + expected = np.array([[1.0, -42.0], [2.0, 3.0]], dtype=np.float64) + np.testing.assert_array_equal(dataset["cdp_x"].values, expected) + + def test_empty_coordinates_is_noop(self) -> None: + """An empty coordinates dict should leave the dataset and drop list untouched.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + live = [(1, 10), (1, 20), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + dataset = xr_Dataset() + + dataset, drop_vars = populate_non_dim_coordinates( + dataset, + grid, + coordinates={}, + drop_vars_delayed=["pre_existing"], + spatial_coordinate_scalar=1, + ) + + assert drop_vars == ["pre_existing"] + assert len(dataset.data_vars) == 0 diff --git a/tests/unit/ingestion/test_grid_qc.py b/tests/unit/ingestion/test_grid_qc.py new file mode 100644 index 00000000..c0a8d475 --- /dev/null +++ b/tests/unit/ingestion/test_grid_qc.py @@ -0,0 +1,117 @@ +"""Tests for ingestion grid density / sparsity quality control.""" + +from __future__ import annotations + +import logging +import os +from unittest.mock import patch + 
+import numpy as np +import pytest + +from mdio.converters.exceptions import GridTraceSparsityError +from mdio.core.dimension import Dimension +from mdio.core.grid import Grid +from mdio.ingestion.grid_qc import grid_density_qc + + +def _make_grid(shape: tuple[int, ...]) -> Grid: + """Build a Grid with named dimensions of the given size.""" + names = [f"dim_{idx}" for idx in range(len(shape) - 1)] + ["sample"] + dims = [Dimension(coords=np.arange(size, dtype=np.int32), name=name) for name, size in zip(names, shape, strict=True)] + return Grid(dims=dims) + + +class TestGridDensityQc: + """Test cases for ``grid_density_qc``.""" + + def test_no_warning_when_dense(self, caplog: pytest.LogCaptureFixture) -> None: + """Dense grids (ratio <= warn) should not log or raise.""" + grid = _make_grid((10, 10, 100)) # 100 grid traces + with caplog.at_level(logging.WARNING): + grid_density_qc(grid, num_traces=100) + assert caplog.records == [] + + def test_warns_when_above_warn_threshold(self, caplog: pytest.LogCaptureFixture) -> None: + """Sparsity above warn but below limit logs a warning, no raise.""" + grid = _make_grid((10, 10, 100)) # 100 grid traces + # warn = 2, error = 10 (defaults); ratio of 5 sits between + with caplog.at_level(logging.WARNING, logger="mdio.ingestion.grid_qc"): + grid_density_qc(grid, num_traces=20) + + warnings = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warnings) == 1 + assert "Sparsity ratio: 5.00" in warnings[0].message + + def test_warning_message_includes_dim_summary(self, caplog: pytest.LogCaptureFixture) -> None: + """The warning body must include shape, trace counts, and a min/max line per dim. + + Pins the format so refactors that touch ``Grid.get_min`` / ``get_max`` or the + message template don't silently regress operator-facing output. 
+ """ + grid = _make_grid((10, 10, 100)) # 100 grid traces + with caplog.at_level(logging.WARNING, logger="mdio.ingestion.grid_qc"): + grid_density_qc(grid, num_traces=20) + + message = caplog.records[0].message + assert "SEG-Y trace count: 20" in message + assert "grid trace count: 100" in message + assert "{'dim_0': 10, 'dim_1': 10, 'sample': 100}" in message + for dim_name in ("dim_0", "dim_1", "sample"): + assert f"\n{dim_name} min: 0 max:" in message + + def test_raises_when_above_limit(self) -> None: + """Sparsity above the error limit should raise.""" + grid = _make_grid((10, 10, 100)) # 100 grid traces, limit default = 10 + with pytest.raises(GridTraceSparsityError): + grid_density_qc(grid, num_traces=5) # ratio 20 > 10 + + def test_ignore_checks_suppresses_error(self, caplog: pytest.LogCaptureFixture) -> None: + """Setting MDIO_IGNORE_CHECKS still warns but never raises.""" + grid = _make_grid((10, 10, 100)) + with patch.dict(os.environ, {"MDIO_IGNORE_CHECKS": "1"}), caplog.at_level( + logging.WARNING, logger="mdio.ingestion.grid_qc" + ): + grid_density_qc(grid, num_traces=5) + + warnings = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warnings) == 1 + + def test_zero_traces_treated_as_infinite_sparsity(self) -> None: + """A SEG-Y with zero traces should be flagged via the limit branch.""" + grid = _make_grid((2, 2, 5)) + with pytest.raises(GridTraceSparsityError): + grid_density_qc(grid, num_traces=0) + + @pytest.mark.parametrize( + ("warn", "limit", "num_traces", "expect_raise", "expect_warn"), + [ + ("100", "1000", "100", False, False), # ratio 1, both safe + ("0.5", "1000", "100", False, True), # ratio 1 > 0.5 (warn only) + ("0.5", "0.9", "100", True, True), # ratio 1 > 0.9 (raise) + ], + ) + def test_thresholds_respect_env_vars( # noqa: PLR0913 + self, + warn: str, + limit: str, + num_traces: str, + expect_raise: bool, + expect_warn: bool, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Custom warn/limit env vars 
should drive the QC behavior.""" + grid = _make_grid((10, 10, 100)) # 100 grid traces + env = { + "MDIO__GRID__SPARSITY_RATIO_WARN": warn, + "MDIO__GRID__SPARSITY_RATIO_LIMIT": limit, + } + with patch.dict(os.environ, env), caplog.at_level(logging.WARNING, logger="mdio.ingestion.grid_qc"): + if expect_raise: + with pytest.raises(GridTraceSparsityError): + grid_density_qc(grid, num_traces=int(num_traces)) + else: + grid_density_qc(grid, num_traces=int(num_traces)) + + warned = any(r.levelno == logging.WARNING for r in caplog.records) + assert warned == expect_warn diff --git a/tests/unit/ingestion/test_metadata.py b/tests/unit/ingestion/test_metadata.py new file mode 100644 index 00000000..40d1eb81 --- /dev/null +++ b/tests/unit/ingestion/test_metadata.py @@ -0,0 +1,42 @@ +"""Tests for generic dataset metadata helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from mdio.ingestion.metadata import _add_grid_override_to_metadata + + +def _make_dataset(attributes: dict | None) -> SimpleNamespace: + """Build a minimal stand-in for Dataset with a nested ``metadata.attributes``.""" + return SimpleNamespace(metadata=SimpleNamespace(attributes=attributes)) + + +class TestAddGridOverrideToMetadata: + """Tests for ``_add_grid_override_to_metadata``.""" + + def test_initializes_attributes_dict_when_none(self) -> None: + """A ``None`` attributes dict gets replaced with an empty dict before insertion.""" + dataset = _make_dataset(attributes=None) + _add_grid_override_to_metadata(dataset, grid_overrides=None) + assert dataset.metadata.attributes == {} + + def test_adds_grid_overrides_when_provided(self) -> None: + """Grid overrides should land under the ``gridOverrides`` key.""" + dataset = _make_dataset(attributes=None) + overrides = {"HasDuplicates": True, "chunksize": 4} + _add_grid_override_to_metadata(dataset, grid_overrides=overrides) + assert dataset.metadata.attributes == {"gridOverrides": overrides} + + def 
test_preserves_existing_attributes(self) -> None: + """Existing attribute keys should be preserved when adding overrides.""" + dataset = _make_dataset(attributes={"existing": "value"}) + overrides = {"NonBinned": True} + _add_grid_override_to_metadata(dataset, grid_overrides=overrides) + assert dataset.metadata.attributes == {"existing": "value", "gridOverrides": overrides} + + def test_no_overrides_leaves_attributes_untouched(self) -> None: + """Passing ``None`` overrides must not introduce a ``gridOverrides`` key.""" + dataset = _make_dataset(attributes={"existing": "value"}) + _add_grid_override_to_metadata(dataset, grid_overrides=None) + assert dataset.metadata.attributes == {"existing": "value"} diff --git a/tests/unit/ingestion/test_segy_coordinates.py b/tests/unit/ingestion/test_segy_coordinates.py new file mode 100644 index 00000000..06c7c6f8 --- /dev/null +++ b/tests/unit/ingestion/test_segy_coordinates.py @@ -0,0 +1,222 @@ +"""Tests for SEG-Y coordinate extraction and unit resolution helpers.""" + +from __future__ import annotations + +import logging +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest +from xarray import DataArray as xr_DataArray +from xarray import Dataset as xr_Dataset + +from mdio.builder.schemas.v1.units import AngleUnitEnum +from mdio.builder.schemas.v1.units import AngleUnitModel +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.templates.base import AbstractDatasetTemplate +from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.ingestion.segy.coordinates import _get_coordinates +from mdio.ingestion.segy.coordinates import _get_spatial_coordinate_unit +from mdio.ingestion.segy.coordinates import _populate_coordinates +from mdio.ingestion.segy.coordinates import _update_template_units +from tests.unit.ingestion.testing_helpers import make_grid +from 
tests.unit.ingestion.testing_helpers import make_grid_with_map +from tests.unit.ingestion.testing_helpers import make_header_array + + +class TestGetCoordinates: + """Tests for ``_get_coordinates``.""" + + def test_returns_dims_and_coords_in_template_order(self) -> None: + """Dim coords and non-dim coords should follow the template's declared order.""" + inline = np.array([1, 2, 3], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4, 8, 12], dtype=np.int32) + grid = make_grid([("inline", inline), ("crossline", crossline), ("time", sample)]) + + n = inline.size * crossline.size + cdp_x = np.arange(n, dtype=np.float64) + cdp_y = np.arange(n, dtype=np.float64) + 100.0 + headers = make_header_array({"cdp_x": cdp_x, "cdp_y": cdp_y}) + + template = Seismic3DPostStackTemplate(data_domain="time") + + dim_coords, non_dim = _get_coordinates(grid, headers, template) + + assert [d.name for d in dim_coords] == ["inline", "crossline", "time"] + np.testing.assert_array_equal(dim_coords[0].coords, inline) + np.testing.assert_array_equal(dim_coords[1].coords, crossline) + assert list(non_dim.keys()) == ["cdp_x", "cdp_y"] + np.testing.assert_array_equal(non_dim["cdp_x"], cdp_x) + np.testing.assert_array_equal(non_dim["cdp_y"], cdp_y) + + def test_missing_dimension_raises(self) -> None: + """A template dim missing from the grid should raise ValueError.""" + grid = make_grid( + [ + ("inline", np.array([1, 2], dtype=np.int32)), + # Missing 'crossline' + ("time", np.array([0, 4], dtype=np.int32)), + ] + ) + headers = make_header_array({"cdp_x": np.zeros(2, dtype=np.float64), "cdp_y": np.zeros(2, dtype=np.float64)}) + + template = Seismic3DPostStackTemplate(data_domain="time") + + with pytest.raises(ValueError, match=r"Dimension 'crossline' was not found"): + _get_coordinates(grid, headers, template) + + def test_missing_coordinate_field_raises(self) -> None: + """A template coord absent from SEG-Y headers should raise ValueError.""" + inline = 
np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + grid = make_grid([("inline", inline), ("crossline", crossline), ("time", sample)]) + # Headers lack 'cdp_y' + headers = make_header_array({"cdp_x": np.zeros(4, dtype=np.float64)}) + + template = Seismic3DPostStackTemplate(data_domain="time") + + with pytest.raises(ValueError, match=r"Coordinate 'cdp_y' not found"): + _get_coordinates(grid, headers, template) + + +class TestPopulateCoordinates: + """Tests for the ``_populate_coordinates`` wrapper. + + These pin the contract that wraps ``populate_dim_coordinates`` + + ``populate_non_dim_coordinates``: dim names land in ``drop_vars`` before coord + names, both halves run, and the wrapper threads its own initially-empty drop + list. + """ + + def test_wraps_dim_and_non_dim_population_in_order(self) -> None: + """Wrapper should populate dims then non-dims and concatenate drop lists.""" + inline = np.array([1, 2], dtype=np.int32) + crossline = np.array([10, 20], dtype=np.int32) + sample = np.array([0, 4], dtype=np.int32) + live = [(1, 10), (1, 20), (2, 10), (2, 20)] + grid = make_grid_with_map( + [("inline", inline), ("crossline", crossline), ("sample", sample)], + live_records=live, + ) + + cdp_x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) + cdp_x_da = xr_DataArray(np.zeros((2, 2), dtype=np.float64), dims=["inline", "crossline"]) + cdp_x_da.encoding["_FillValue"] = np.float64(-1.0) + + dataset = xr_Dataset( + { + "inline": xr_DataArray(np.zeros(2, dtype=np.int32), dims=["inline"]), + "crossline": xr_DataArray(np.zeros(2, dtype=np.int32), dims=["crossline"]), + "sample": xr_DataArray(np.zeros(2, dtype=np.int32), dims=["sample"]), + "cdp_x": cdp_x_da, + } + ) + + dataset, drop_vars = _populate_coordinates( + dataset, + grid, + coords={"cdp_x": cdp_x}, + spatial_coordinate_scalar=1, + ) + + np.testing.assert_array_equal(dataset["inline"].values, inline) + 
np.testing.assert_array_equal(dataset["crossline"].values, crossline) + np.testing.assert_array_equal(dataset["sample"].values, sample) + np.testing.assert_array_equal(dataset["cdp_x"].values, cdp_x.reshape((2, 2))) + # Dim names recorded first, then non-dim coord names. + assert drop_vars == ["inline", "crossline", "sample", "cdp_x"] + + +class TestGetSpatialCoordinateUnit: + """Tests for ``_get_spatial_coordinate_unit``.""" + + @pytest.mark.parametrize( + ("code", "expected_unit"), + [ + (1, LengthUnitEnum.METER), + (2, LengthUnitEnum.FOOT), + ], + ) + def test_known_measurement_codes(self, code: int, expected_unit: LengthUnitEnum) -> None: + """Codes 1 (m) and 2 (ft) return the corresponding length unit.""" + info = SimpleNamespace(binary_header_dict={"measurement_system_code": code}) + result = _get_spatial_coordinate_unit(info) + assert isinstance(result, LengthUnitModel) + assert result.length == expected_unit + + def test_unknown_code_returns_none_and_warns(self, caplog: pytest.LogCaptureFixture) -> None: + """Unexpected codes should log a warning and return ``None``.""" + info = SimpleNamespace(binary_header_dict={"measurement_system_code": 7}) + with caplog.at_level(logging.WARNING, logger="mdio.ingestion.segy.coordinates"): + result = _get_spatial_coordinate_unit(info) + assert result is None + assert any("Unexpected value in coordinate unit" in r.message for r in caplog.records) + + +class TestUpdateTemplateUnits: + """Tests for ``_update_template_units``.""" + + def _stub_template(self) -> MagicMock: + template = MagicMock(spec=AbstractDatasetTemplate) + template.get_unit_by_key.return_value = None + return template + + def test_adds_angle_units_only_when_spatial_unit_missing(self) -> None: + """Without a spatial unit, only angle units should be added.""" + template = self._stub_template() + result = _update_template_units(template, unit=None) + + template.add_units.assert_called_once() + added = template.add_units.call_args.args[0] + assert 
set(added.keys()) == {"angle", "azimuth"} + for unit in added.values(): + assert isinstance(unit, AngleUnitModel) + assert unit.angle == AngleUnitEnum.DEGREES + assert result is template + + def test_adds_spatial_units_when_unit_provided(self) -> None: + """A non-None unit should populate all SPATIAL keys plus angle keys.""" + template = self._stub_template() + unit = LengthUnitModel(length=LengthUnitEnum.METER) + + _update_template_units(template, unit=unit) + + added = template.add_units.call_args.args[0] + expected_keys = { + "angle", + "azimuth", + "cdp_x", + "cdp_y", + "source_coord_x", + "source_coord_y", + "group_coord_x", + "group_coord_y", + "offset", + } + assert set(added.keys()) == expected_keys + for key in {"cdp_x", "cdp_y", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y", "offset"}: + assert added[key] is unit + + def test_preserves_pre_existing_spatial_units(self, caplog: pytest.LogCaptureFixture) -> None: + """Keys that already have a template unit must not be overwritten.""" + existing = LengthUnitModel(length=LengthUnitEnum.FOOT) + new_unit = LengthUnitModel(length=LengthUnitEnum.METER) + + template = MagicMock(spec=AbstractDatasetTemplate) + + def fake_lookup(key: str) -> LengthUnitModel | None: + return existing if key == "cdp_x" else None + + template.get_unit_by_key.side_effect = fake_lookup + + with caplog.at_level(logging.WARNING, logger="mdio.ingestion.segy.coordinates"): + _update_template_units(template, unit=new_unit) + + added = template.add_units.call_args.args[0] + assert "cdp_x" not in added + assert added["cdp_y"] is new_unit + assert any("already in template" in r.message for r in caplog.records) diff --git a/tests/unit/ingestion/test_segy_file_headers.py b/tests/unit/ingestion/test_segy_file_headers.py new file mode 100644 index 00000000..b831ed15 --- /dev/null +++ b/tests/unit/ingestion/test_segy_file_headers.py @@ -0,0 +1,94 @@ +"""Tests for attaching SEG-Y text/binary file headers to xarray datasets.""" + 
+from __future__ import annotations + +import base64 +import os +from types import SimpleNamespace +from unittest.mock import patch + +import numpy as np +import pytest +from xarray import DataArray as xr_DataArray +from xarray import Dataset as xr_Dataset + +from mdio.ingestion.segy.file_headers import _add_segy_file_headers + + +def _valid_text_header() -> str: + """Build a SEG-Y text header with 40 rows of exactly 80 chars.""" + return "\n".join(["X" * 80] * 40) + + +def _make_segy_info( + text_header: str | None = None, + binary_header_dict: dict | None = None, + raw_binary_headers: bytes = b"\x00" * 400, +) -> SimpleNamespace: + return SimpleNamespace( + text_header=text_header if text_header is not None else _valid_text_header(), + binary_header_dict=binary_header_dict if binary_header_dict is not None else {"sample_interval": 4000}, + raw_binary_headers=raw_binary_headers, + ) + + +def _empty_dataset() -> xr_Dataset: + return xr_Dataset({"amplitude": xr_DataArray(np.zeros(2, dtype=np.float32), dims=["inline"])}) + + +class TestAddSegyFileHeaders: + """Tests for ``_add_segy_file_headers``.""" + + def test_disabled_returns_dataset_unchanged(self) -> None: + """When the save flag is off the dataset must not be modified.""" + info = _make_segy_info() + ds = _empty_dataset() + with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "false"}): + result = _add_segy_file_headers(ds, info) + + assert "segy_file_header" not in result + + def test_attaches_headers_when_enabled(self) -> None: + """Enabling the flag should attach text + binary header attrs.""" + info = _make_segy_info() + ds = _empty_dataset() + with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}): + result = _add_segy_file_headers(ds, info) + + attrs = result["segy_file_header"].attrs + assert attrs["textHeader"] == info.text_header + assert attrs["binaryHeader"] == info.binary_header_dict + assert "rawBinaryHeader" not in attrs + + def 
test_attaches_raw_binary_when_raw_flag_enabled(self) -> None: + """``raw_headers`` should add the base64-encoded raw binary headers.""" + info = _make_segy_info(raw_binary_headers=b"abc") + ds = _empty_dataset() + env = { + "MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true", + "MDIO__IMPORT__RAW_HEADERS": "true", + } + with patch.dict(os.environ, env): + result = _add_segy_file_headers(ds, info) + + encoded = base64.b64encode(b"abc").decode("ascii") + assert result["segy_file_header"].attrs["rawBinaryHeader"] == encoded + + def test_invalid_row_count_raises(self) -> None: + """Text header without 40 rows must raise.""" + bad_text = "\n".join(["X" * 80] * 39) + info = _make_segy_info(text_header=bad_text) + ds = _empty_dataset() + with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}): + with pytest.raises(ValueError, match="Invalid text header count"): + _add_segy_file_headers(ds, info) + + def test_invalid_column_count_raises(self) -> None: + """Text header rows shorter than 80 chars must raise.""" + bad_rows = ["X" * 80] * 40 + bad_rows[5] = "X" * 79 + info = _make_segy_info(text_header="\n".join(bad_rows)) + ds = _empty_dataset() + with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}): + with pytest.raises(ValueError, match="Invalid text header columns"): + _add_segy_file_headers(ds, info) diff --git a/tests/unit/ingestion/test_segy_header_analysis.py b/tests/unit/ingestion/test_segy_header_analysis.py new file mode 100644 index 00000000..1f1a7ff7 --- /dev/null +++ b/tests/unit/ingestion/test_segy_header_analysis.py @@ -0,0 +1,242 @@ +"""Tests for SEG-Y header analysis acquisition-geometry helpers.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from mdio.ingestion.segy.header_analysis import ShotGunGeometryType +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType +from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns +from 
mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers +from mdio.ingestion.segy.header_analysis import analyze_saillines_for_guns +from mdio.ingestion.segy.header_analysis import analyze_streamer_headers +from mdio.ingestion.segy.header_analysis import create_counter +from mdio.ingestion.segy.header_analysis import create_trace_index +from tests.unit.ingestion.testing_helpers import make_header_array + + +def _streamer_headers(records: list[tuple[int, int]]) -> np.ndarray: + """Build a (cable, channel) header array from ``(cable, channel)`` pairs.""" + cables, channels = zip(*records, strict=True) + return make_header_array( + { + "cable": np.asarray(cables, dtype=np.int32), + "channel": np.asarray(channels, dtype=np.int32), + } + ) + + +def _gun_headers(records: list[tuple[int, int, int]], line_field: str = "sail_line") -> np.ndarray: + """Build a (line_field, shot_point, gun) header array from triples.""" + lines, shots, guns = zip(*records, strict=True) + return make_header_array( + { + line_field: np.asarray(lines, dtype=np.int32), + "shot_point": np.asarray(shots, dtype=np.int32), + "gun": np.asarray(guns, dtype=np.int8), + } + ) + + +class TestAnalyzeStreamerHeaders: + """Tests for ``analyze_streamer_headers``.""" + + def test_non_overlapping_channels_returns_type_b(self) -> None: + """Non-overlapping cable channel ranges should produce Configuration B.""" + records: list[tuple[int, int]] = [] + for cable in (1, 2, 3): + for chan in range(1, 6): + records.append((cable, (cable - 1) * 5 + chan)) + + headers = _streamer_headers(records) + + unique_cables, mins, maxs, geom = analyze_streamer_headers(headers) + + np.testing.assert_array_equal(unique_cables, [1, 2, 3]) + np.testing.assert_array_equal(mins, [1, 6, 11]) + np.testing.assert_array_equal(maxs, [5, 10, 15]) + assert geom is StreamerShotGeometryType.B + + def test_overlapping_channels_returns_type_a(self) -> None: + """Overlapping channel ranges between cables should produce 
Configuration A.""" + records: list[tuple[int, int]] = [] + for cable in (1, 2): + for chan in range(1, 6): + records.append((cable, chan)) + + headers = _streamer_headers(records) + unique_cables, _, _, geom = analyze_streamer_headers(headers) + + np.testing.assert_array_equal(unique_cables, [1, 2]) + assert geom is StreamerShotGeometryType.A + + def test_single_cable_returns_type_b(self) -> None: + """A single cable has no neighbours to overlap with → Configuration B.""" + records = [(7, chan) for chan in range(1, 6)] + headers = _streamer_headers(records) + + unique_cables, mins, maxs, geom = analyze_streamer_headers(headers) + + np.testing.assert_array_equal(unique_cables, [7]) + np.testing.assert_array_equal(mins, [1]) + np.testing.assert_array_equal(maxs, [5]) + assert geom is StreamerShotGeometryType.B + + +class TestAnalyzeLinesForGuns: + """Tests for ``analyze_lines_for_guns``.""" + + def test_dense_shots_per_gun_returns_type_a(self) -> None: + """Each gun densely numbered 1..N (overlap across guns) -> Configuration A.""" + # Gun 1: shots 1..4, Gun 2: shots 1..4 (line value 100) + records = [(100, shot, gun) for gun in (1, 2) for shot in range(1, 5)] + headers = _gun_headers(records) + + unique_lines, per_line, geom = analyze_lines_for_guns(headers, line_field="sail_line") + + np.testing.assert_array_equal(unique_lines, [100]) + assert per_line == {"100": [1, 2]} + assert geom is ShotGunGeometryType.A + + def test_interleaved_shots_returns_type_b(self) -> None: + """Interleaved shot numbering (unique per line, sparse per gun) -> Configuration B.""" + # Gun 1: odd shots, gun 2: even shots, all unique within the same line. 
+ records = [] + for shot in (1, 3, 5): + records.append((200, shot, 1)) + for shot in (2, 4, 6): + records.append((200, shot, 2)) + headers = _gun_headers(records) + + unique_lines, per_line, geom = analyze_lines_for_guns(headers, line_field="sail_line") + + np.testing.assert_array_equal(unique_lines, [200]) + assert per_line == {"200": [1, 2]} + assert geom is ShotGunGeometryType.B + + def test_custom_line_field(self) -> None: + """Function must work for any line-field name (e.g. ``shot_line``).""" + records = [(7, shot, gun) for gun in (1, 2) for shot in range(1, 4)] + headers = _gun_headers(records, line_field="shot_line") + + unique_lines, per_line, geom = analyze_lines_for_guns(headers, line_field="shot_line") + + np.testing.assert_array_equal(unique_lines, [7]) + assert per_line == {"7": [1, 2]} + # Dense 1..N per gun → configuration A + assert geom is ShotGunGeometryType.A + + def test_single_gun_line_stays_type_b(self) -> None: + """A line with a single gun has nothing to interleave → Configuration B. + + With ``num_guns == 1``, dividing shot points by 1 is the identity, so the + floor/unique check trivially matches and the function should stay in B. 
+ """ + records = [(300, shot, 1) for shot in range(1, 5)] + headers = _gun_headers(records) + + unique_lines, per_line, geom = analyze_lines_for_guns(headers, line_field="sail_line") + + np.testing.assert_array_equal(unique_lines, [300]) + assert per_line == {"300": [1]} + assert geom is ShotGunGeometryType.B + + +class TestAnalyzeSailLinesForGuns: + """Tests for ``analyze_saillines_for_guns``.""" + + def test_delegates_to_generic_function(self) -> None: + """The sail-line variant should mirror ``analyze_lines_for_guns`` with ``sail_line``.""" + records = [(11, shot, gun) for gun in (1, 2) for shot in range(1, 4)] + headers = _gun_headers(records, line_field="sail_line") + + lines_a, per_line_a, geom_a = analyze_saillines_for_guns(headers) + lines_b, per_line_b, geom_b = analyze_lines_for_guns(headers, line_field="sail_line") + + np.testing.assert_array_equal(lines_a, lines_b) + assert per_line_a == per_line_b + assert geom_a is geom_b + # Sanity check on the underlying geometry (dense per gun -> A). 
+ assert geom_a is ShotGunGeometryType.A + + +class TestCreateCounter: + """Tests for ``create_counter``.""" + + def test_returns_zero_at_max_depth(self) -> None: + """Reaching the requested total depth yields the leaf integer 0.""" + assert create_counter(2, 2, {}, []) == 0 + + def test_builds_nested_dict_per_header(self) -> None: + """Two-level tree should have a leaf 0 under each combination.""" + unique = {"cable": np.array([1, 2]), "channel": np.array([10, 11, 12])} + names = ["cable", "channel"] + + tree = create_counter(0, 2, unique, names) + + assert set(tree.keys()) == {1, 2} + for cable in (1, 2): + assert set(tree[cable].keys()) == {10, 11, 12} + for chan in (10, 11, 12): + assert tree[cable][chan] == 0 + + +class TestCreateTraceIndex: + """Tests for ``create_trace_index``.""" + + def test_returns_none_when_depth_zero(self) -> None: + """A zero-depth tree means no header names → None.""" + headers = np.array([], dtype=[("cable", "i4")]) + assert create_trace_index(0, {}, headers, []) is None + + def test_assigns_dense_trace_index_for_single_header(self) -> None: + """One-dim counter assigns 1..N within each unique header value.""" + cable = np.array([1, 1, 2, 2, 2], dtype=np.int32) + headers = np.empty(cable.size, dtype=[("cable", "i4")]) + headers["cable"] = cable + counter = {1: 0, 2: 0} + + out = create_trace_index(1, counter, headers, ["cable"]) + + assert out is not None + assert "trace" in out.dtype.names + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2, 3]) + + def test_assigns_dense_trace_index_for_two_headers(self) -> None: + """Two-dim counter assigns 1..N within each unique (cable, channel).""" + cable = np.array([1, 1, 2, 2, 2], dtype=np.int32) + channel = np.array([10, 10, 20, 20, 30], dtype=np.int32) + headers = np.empty(cable.size, dtype=[("cable", "i4"), ("channel", "i4")]) + headers["cable"] = cable + headers["channel"] = channel + counter = {1: {10: 0, 20: 0, 30: 0}, 2: {10: 0, 20: 0, 30: 0}} + + out = create_trace_index(2, 
counter, headers, ["cable", "channel"]) + + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2, 1]) + + +class TestAnalyzeNonIndexedHeaders: + """Tests for ``analyze_non_indexed_headers``.""" + + def test_adds_trace_field_with_dense_index(self) -> None: + """The returned header array should carry a 'trace' field counting within unique keys.""" + cable = np.array([1, 1, 1, 2, 2], dtype=np.int32) + headers = np.empty(cable.size, dtype=[("cable", "i4")]) + headers["cable"] = cable + + out = analyze_non_indexed_headers(headers) + + assert "trace" in out.dtype.names + np.testing.assert_array_equal(out["trace"], [1, 2, 3, 1, 2]) + + @pytest.mark.parametrize("dtype", [np.int16, np.int32, np.int64]) + def test_respects_dtype_kwarg(self, dtype: type[np.integer]) -> None: + """The dtype kwarg should drive the 'trace' field's numpy dtype.""" + cable = np.array([1, 2], dtype=np.int32) + headers = np.empty(cable.size, dtype=[("cable", "i4")]) + headers["cable"] = cable + + out = analyze_non_indexed_headers(headers, dtype=dtype) + assert out["trace"].dtype == np.dtype(dtype) diff --git a/tests/unit/ingestion/test_segy_validation.py b/tests/unit/ingestion/test_segy_validation.py new file mode 100644 index 00000000..076c6db2 --- /dev/null +++ b/tests/unit/ingestion/test_segy_validation.py @@ -0,0 +1,104 @@ +"""Tests for SEG-Y spec/template validation (canonical ingestion path).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from segy.schema import HeaderField +from segy.standards import get_segy_standard + +from mdio.builder.template_registry import TemplateRegistry +from mdio.builder.templates.base import AbstractDatasetTemplate +from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate +from mdio.ingestion.segy.validation import _validate_spec_in_template + + +class TestValidateSpecInTemplate: + """Direct tests for the canonical ``mdio.ingestion.segy.validation`` module.""" + + def 
test_passes_with_all_required_fields(self) -> None: + """All declared dim/coord fields present → no error.""" + template = MagicMock(spec=AbstractDatasetTemplate) + template.spatial_dimension_names = ("inline", "crossline") + template.coordinate_names = ("cdp_x", "cdp_y") + template.calculated_dimension_names = () + + segy_spec = get_segy_standard(1.0) + _validate_spec_in_template(segy_spec, template) + + def test_missing_fields_listed_in_error(self) -> None: + """The error message must enumerate all missing required fields.""" + template = MagicMock(spec=AbstractDatasetTemplate) + template.name = "CustomTemplate" + template.spatial_dimension_names = ("custom_dim1", "custom_dim2") + template.coordinate_names = ("custom_coord_x",) + template.calculated_dimension_names = () + + spec = get_segy_standard(1.0) + # Only one of the custom dims is present + spec = spec.customize(trace_header_fields=[HeaderField(name="custom_dim1", byte=189, format="int32")]) + + with pytest.raises(ValueError, match=r"Required fields.*not found in.*segy_spec") as exc: + _validate_spec_in_template(spec, template) + + msg = str(exc.value) + assert "custom_dim2" in msg + assert "custom_coord_x" in msg + assert "CustomTemplate" in msg + + def test_missing_coordinate_scalar_raises(self) -> None: + """A spec without ``coordinate_scalar`` must always fail.""" + template = MagicMock(spec=AbstractDatasetTemplate) + template.name = "TestTemplate" + template.spatial_dimension_names = ("inline", "crossline") + template.coordinate_names = ("cdp_x", "cdp_y") + template.calculated_dimension_names = () + + spec = get_segy_standard(1.0) + kept = [f for f in spec.trace.header.fields if f.name != "coordinate_scalar"] + kept.append(HeaderField(name="not_coordinate_scalar", byte=71, format="int16")) + spec = spec.customize(trace_header_fields=kept) + + with pytest.raises(ValueError, match=r"coordinate_scalar"): + _validate_spec_in_template(spec, template) + + def 
test_calculated_dimensions_are_not_required(self) -> None: + """Dimensions in ``calculated_dimension_names`` should not be required from the spec.""" + template = MagicMock(spec=AbstractDatasetTemplate) + template.name = "CalcDim" + template.spatial_dimension_names = ("inline", "crossline", "calculated_only") + template.coordinate_names = ("cdp_x", "cdp_y") + template.calculated_dimension_names = ("calculated_only",) + + segy_spec = get_segy_standard(1.0) + _validate_spec_in_template(segy_spec, template) + + def test_obn_template_excludes_component_requirement(self) -> None: + """OBN templates synthesize ``component`` when absent → not required from spec.""" + template = Seismic3DObnReceiverGathersTemplate(data_domain="time") + # Make sure the registry has it (registry use is independent of validation). + assert TemplateRegistry().get("ObnReceiverGathers3D") is not None + + spec = get_segy_standard(1.0) + # Add all required OBN fields except 'component'. + required = ( + set(template.spatial_dimension_names) | set(template.coordinate_names) + ) - set(template.calculated_dimension_names) + required.discard("component") + + extra = [HeaderField(name=name, byte=189, format="int32") for name in sorted(required)] + # Spread bytes so they don't collide. + spec = spec.customize( + trace_header_fields=[HeaderField(name=f.name, byte=189 + idx * 4, format="int32") for idx, f in enumerate(extra)] + ) + + _validate_spec_in_template(spec, template) + + def test_obn_template_missing_other_required_field_still_fails(self) -> None: + """Even with the ``component`` carve-out, other missing fields should error.""" + template = Seismic3DObnReceiverGathersTemplate(data_domain="time") + spec = get_segy_standard(1.0) # missing OBN-specific fields like 'receiver', 'shot_line', etc. 
+ + with pytest.raises(ValueError, match=r"Required fields.*not found"): + _validate_spec_in_template(spec, template) diff --git a/tests/unit/ingestion/testing_helpers.py b/tests/unit/ingestion/testing_helpers.py new file mode 100644 index 00000000..ea01fe99 --- /dev/null +++ b/tests/unit/ingestion/testing_helpers.py @@ -0,0 +1,73 @@ +"""Shared builders for ingestion unit tests.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from mdio.core.dimension import Dimension +from mdio.core.grid import Grid + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import NDArray + + +def make_grid(dim_specs: list[tuple[str, NDArray]]) -> Grid: + """Build a Grid from a list of ``(name, coords)`` pairs.""" + dims = [Dimension(coords=coords, name=name) for name, coords in dim_specs] + return Grid(dims=dims) + + +def make_grid_with_map( + dim_specs: list[tuple[str, NDArray]], + live_records: Sequence[tuple], +) -> Grid: + """Build a Grid and populate its trace map via ``Grid.build_map``. + + The trace index for each record matches its position in ``live_records``, exactly + mirroring how production ingestion code assigns trace ordinals when streaming + SEG-Y headers through ``Grid.build_map``. + + Args: + dim_specs: Ordered ``(name, coords)`` pairs. The last entry is the vertical + (sample/depth) dimension and is excluded from the trace map per + ``Grid.build_map`` conventions. + live_records: Per-trace tuples giving the value of each non-sample + dimension, in dimension order. Cells absent from this list remain at + the map's fill value. + + Returns: + Grid with a real Zarr-backed ``map`` and ``live_mask`` populated. 
+ """ + grid = make_grid(dim_specs) + non_sample_dims = grid.dims[:-1] + names = [d.name for d in non_sample_dims] + formats = [np.asarray(d.coords).dtype for d in non_sample_dims] + header_dtype = np.dtype({"names": names, "formats": formats}) + headers = np.empty(len(live_records), dtype=header_dtype) + for idx, values in enumerate(live_records): + for name, value in zip(names, values, strict=True): + headers[name][idx] = value + grid.build_map(headers) + return grid + + +def make_header_array(field_values: dict[str, NDArray]) -> NDArray: + """Build a structured numpy array mimicking a SEG-Y HeaderArray. + + Args: + field_values: Mapping of field name to a 1-D array of values. All arrays must + share the same shape. + + Returns: + Structured ``ndarray`` with one named column per field. + """ + sample = next(iter(field_values.values())) + dtype = np.dtype({"names": list(field_values), "formats": [v.dtype for v in field_values.values()]}) + arr = np.empty(sample.shape, dtype=dtype) + for name, values in field_values.items(): + arr[name] = values + return arr From d7b638b2d01fbcb5975d111cc8f1eddcfe7ffdb0 Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Mon, 18 May 2026 20:33:43 +0000 Subject: [PATCH 5/6] Precommit --- tests/unit/ingestion/test_grid_qc.py | 9 ++++++--- tests/unit/ingestion/test_segy_coordinates.py | 2 +- tests/unit/ingestion/test_segy_validation.py | 10 ++++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/unit/ingestion/test_grid_qc.py b/tests/unit/ingestion/test_grid_qc.py index c0a8d475..ca8fbb27 100644 --- a/tests/unit/ingestion/test_grid_qc.py +++ b/tests/unit/ingestion/test_grid_qc.py @@ -18,7 +18,9 @@ def _make_grid(shape: tuple[int, ...]) -> Grid: """Build a Grid with named dimensions of the given size.""" names = [f"dim_{idx}" for idx in range(len(shape) - 1)] + ["sample"] - dims = [Dimension(coords=np.arange(size, dtype=np.int32), name=name) for name, size in zip(names, shape, strict=True)] + dims = [ + 
Dimension(coords=np.arange(size, dtype=np.int32), name=name) for name, size in zip(names, shape, strict=True) + ] return Grid(dims=dims) @@ -69,8 +71,9 @@ def test_raises_when_above_limit(self) -> None: def test_ignore_checks_suppresses_error(self, caplog: pytest.LogCaptureFixture) -> None: """Setting MDIO_IGNORE_CHECKS still warns but never raises.""" grid = _make_grid((10, 10, 100)) - with patch.dict(os.environ, {"MDIO_IGNORE_CHECKS": "1"}), caplog.at_level( - logging.WARNING, logger="mdio.ingestion.grid_qc" + with ( + patch.dict(os.environ, {"MDIO_IGNORE_CHECKS": "1"}), + caplog.at_level(logging.WARNING, logger="mdio.ingestion.grid_qc"), ): grid_density_qc(grid, num_traces=5) diff --git a/tests/unit/ingestion/test_segy_coordinates.py b/tests/unit/ingestion/test_segy_coordinates.py index 06c7c6f8..a857440f 100644 --- a/tests/unit/ingestion/test_segy_coordinates.py +++ b/tests/unit/ingestion/test_segy_coordinates.py @@ -198,7 +198,7 @@ def test_adds_spatial_units_when_unit_provided(self) -> None: "offset", } assert set(added.keys()) == expected_keys - for key in {"cdp_x", "cdp_y", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y", "offset"}: + for key in ("cdp_x", "cdp_y", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y", "offset"): assert added[key] is unit def test_preserves_pre_existing_spatial_units(self, caplog: pytest.LogCaptureFixture) -> None: diff --git a/tests/unit/ingestion/test_segy_validation.py b/tests/unit/ingestion/test_segy_validation.py index 076c6db2..38f1b75f 100644 --- a/tests/unit/ingestion/test_segy_validation.py +++ b/tests/unit/ingestion/test_segy_validation.py @@ -82,15 +82,17 @@ def test_obn_template_excludes_component_requirement(self) -> None: spec = get_segy_standard(1.0) # Add all required OBN fields except 'component'. 
- required = ( - set(template.spatial_dimension_names) | set(template.coordinate_names) - ) - set(template.calculated_dimension_names) + required = (set(template.spatial_dimension_names) | set(template.coordinate_names)) - set( + template.calculated_dimension_names + ) required.discard("component") extra = [HeaderField(name=name, byte=189, format="int32") for name in sorted(required)] # Spread bytes so they don't collide. spec = spec.customize( - trace_header_fields=[HeaderField(name=f.name, byte=189 + idx * 4, format="int32") for idx, f in enumerate(extra)] + trace_header_fields=[ + HeaderField(name=f.name, byte=189 + idx * 4, format="int32") for idx, f in enumerate(extra) + ] ) _validate_spec_in_template(spec, template) From 8b453c287ae9f680ccf7f5c7b24bda9334dd7415 Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Mon, 18 May 2026 20:35:39 +0000 Subject: [PATCH 6/6] pre-commit --- .../unit/ingestion/test_segy_file_headers.py | 16 ++++++++++------ .../ingestion/test_segy_header_analysis.py | 19 ++++--------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/tests/unit/ingestion/test_segy_file_headers.py b/tests/unit/ingestion/test_segy_file_headers.py index b831ed15..6a76d871 100644 --- a/tests/unit/ingestion/test_segy_file_headers.py +++ b/tests/unit/ingestion/test_segy_file_headers.py @@ -79,9 +79,11 @@ def test_invalid_row_count_raises(self) -> None: bad_text = "\n".join(["X" * 80] * 39) info = _make_segy_info(text_header=bad_text) ds = _empty_dataset() - with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}): - with pytest.raises(ValueError, match="Invalid text header count"): - _add_segy_file_headers(ds, info) + with ( + patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}), + pytest.raises(ValueError, match="Invalid text header count"), + ): + _add_segy_file_headers(ds, info) def test_invalid_column_count_raises(self) -> None: """Text header rows shorter than 80 chars must raise.""" @@ -89,6 +91,8 
@@ def test_invalid_column_count_raises(self) -> None: bad_rows[5] = "X" * 79 info = _make_segy_info(text_header="\n".join(bad_rows)) ds = _empty_dataset() - with patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}): - with pytest.raises(ValueError, match="Invalid text header columns"): - _add_segy_file_headers(ds, info) + with ( + patch.dict(os.environ, {"MDIO__IMPORT__SAVE_SEGY_FILE_HEADER": "true"}), + pytest.raises(ValueError, match="Invalid text header columns"), + ): + _add_segy_file_headers(ds, info) diff --git a/tests/unit/ingestion/test_segy_header_analysis.py b/tests/unit/ingestion/test_segy_header_analysis.py index 1f1a7ff7..92fdf616 100644 --- a/tests/unit/ingestion/test_segy_header_analysis.py +++ b/tests/unit/ingestion/test_segy_header_analysis.py @@ -44,11 +44,7 @@ class TestAnalyzeStreamerHeaders: def test_non_overlapping_channels_returns_type_b(self) -> None: """Non-overlapping cable channel ranges should produce Configuration B.""" - records: list[tuple[int, int]] = [] - for cable in (1, 2, 3): - for chan in range(1, 6): - records.append((cable, (cable - 1) * 5 + chan)) - + records = [(cable, (cable - 1) * 5 + chan) for cable in (1, 2, 3) for chan in range(1, 6)] headers = _streamer_headers(records) unique_cables, mins, maxs, geom = analyze_streamer_headers(headers) @@ -60,11 +56,7 @@ def test_non_overlapping_channels_returns_type_b(self) -> None: def test_overlapping_channels_returns_type_a(self) -> None: """Overlapping channel ranges between cables should produce Configuration A.""" - records: list[tuple[int, int]] = [] - for cable in (1, 2): - for chan in range(1, 6): - records.append((cable, chan)) - + records = [(cable, chan) for cable in (1, 2) for chan in range(1, 6)] headers = _streamer_headers(records) unique_cables, _, _, geom = analyze_streamer_headers(headers) @@ -102,11 +94,8 @@ def test_dense_shots_per_gun_returns_type_a(self) -> None: def test_interleaved_shots_returns_type_b(self) -> None: """Interleaved shot 
numbering (unique per line, sparse per gun) -> Configuration B.""" # Gun 1: odd shots, gun 2: even shots, all unique within the same line. - records = [] - for shot in (1, 3, 5): - records.append((200, shot, 1)) - for shot in (2, 4, 6): - records.append((200, shot, 2)) + records = [(200, shot, 1) for shot in (1, 3, 5)] + records.extend((200, shot, 2) for shot in (2, 4, 6)) headers = _gun_headers(records) unique_lines, per_line, geom = analyze_lines_for_guns(headers, line_field="sail_line")