diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..c788940e --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,53 @@ +name: Tests + +# Runs on every PR and on pushes to main. +on: + push: + branches: [main] + pull_request: + workflow_dispatch: + +jobs: + unit-tests: + name: Unit tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false # run both Python versions even if one fails, so we see both + matrix: + # 3.11+ only: the package uses `match` and `int | float` unions (3.10+) + # and requires-python is >=3.11. Older versions fail on import. + python-version: ["3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + cache-dependency-path: package/pyproject.toml + + # The workflow lives at repo root, but the package is in ./package, + # so every step below runs from there (matches local dev from package/). + - name: Install package with test dependencies + working-directory: ./package + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + # --check / --check-only report only; they never edit. CI fails if code + # was committed unformatted. Mirrors the local pre-commit checks. + - name: Check formatting (Black) + working-directory: ./package + run: black --check tests/ + + - name: Check import order (isort) + working-directory: ./package + run: isort --check-only tests/ + + # -m "not integration": runs everything except tests marked @pytest.mark.integration. + # (no integration tests yet) + - name: Run unit tests + working-directory: ./package + run: pytest -v -m "not integration" \ No newline at end of file diff --git a/package/pyproject.toml b/package/pyproject.toml index ea50b37d..6efe5ce1 100644 --- a/package/pyproject.toml +++ b/package/pyproject.toml @@ -5,7 +5,7 @@ description = "A Python package for generating methods sections in for ASL param authors = [{ name="Ibrahim Abdelazim", email="ibrahim.abdelazim@fau.de" }] license = "MIT" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.11" dependencies = [ "PyYAML~=6.0.2", "numpy~=2.2.6", @@ -26,4 +26,37 @@ include-package-data = true where = ["src"] [tool.setuptools.package-data] -"pyaslreport" = ["*.yaml", "**/*.yaml"] \ No newline at end of file +"pyaslreport" = ["*.yaml", "**/*.yaml"] + +[project.optional-dependencies] +test = [ + "pytest>=7.0", + "pytest-cov>=4.0", + "black==26.5.1", + "isort==8.0.1", +] + +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "-ra", + "--strict-markers", + "--strict-config", +] +markers = [ + "integration: integration tests requiring file fixtures (deselect with '-m \"not integration\"')", + "slow: tests that take noticeably longer to run", +] + +[tool.black] +line-length = 88 +target-version = ["py311"] + +[tool.isort] +profile = "black" +line_length = 88 +known_first_party = ["pyaslreport"] \ No newline at end of file diff --git a/package/src/pyaslreport/modalities/asl/utils.py b/package/src/pyaslreport/modalities/asl/utils.py index f23dccd4..1c55d1f9 100644 --- a/package/src/pyaslreport/modalities/asl/utils.py +++ b/package/src/pyaslreport/modalities/asl/utils.py @@ -48,7 +48,7 @@ def compare_params(params_asl, params_m0, asl_filename, m0_filename): f"Discrepancy in '{param}' for ASL file '{asl_filename}' and M0 file '{m0_filename}': " f"ASL value = {asl_value}, M0 value = {m0_value}") elif validation_type == "floatOrArray": - if isinstance(asl_value, float) and isinstance(m0_value, float): + if isinstance(asl_value, (int, float)) and isinstance(m0_value, (int, float)): difference = abs(asl_value - m0_value) difference_formatted = f"{difference:.2f}" if difference > error_variation: diff --git a/package/src/pyaslreport/tests/__init__.py b/package/tests/__init__.py similarity index 100% rename from package/src/pyaslreport/tests/__init__.py rename to package/tests/__init__.py diff --git a/package/tests/conftest.py b/package/tests/conftest.py new file mode 100644 index 00000000..4fcb386d --- /dev/null +++ b/package/tests/conftest.py @@ -0,0 +1,198 @@ +"""Shared fixtures and pytest configuration for pyaslreport tests. + +Fixtures provided: + minimal_nifti_path: Path to a tiny on-disk NIfTI file (auto-deleted after test). + make_context: Factory for building ProcessingContext with sensible defaults. + make_processor: Factory for building ASLProcessor without triggering validation. + examples_dir: Path to the integration examples directory (pytest-configurable). + minimal_asl_json: In-memory dict representing a minimal valid ASL JSON. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable + +import nibabel as nib +import numpy as np +import pytest + +from pyaslreport.modalities.asl.processor import ASLProcessor, ProcessingContext + + +# --------------------------------------------------------------------------- +# CLI option for integration test directory +# --------------------------------------------------------------------------- +def pytest_addoption(parser: pytest.Parser) -> None: + """Register the --examples-dir CLI option for the integration runner. + + Args: + parser: The pytest CLI parser, supplied by pytest at collection time. + + Notes: + Local usage: pytest --examples-dir=/path/to/examples + CI usage: omit the flag; defaults to the committed set in + tests/integration/examples. + """ + parser.addoption( + "--examples-dir", + action="store", + default=None, + help="Path to integration examples dir; falls back to the committed CI set.", + ) + + +# --------------------------------------------------------------------------- +# File-based fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def minimal_nifti_path(tmp_path: Path) -> Path: + """Create a tiny valid NIfTI file at tmp_path/asl.nii.gz. + + Args: + tmp_path: Pytest-provided per-test temporary directory. + + Returns: + Path to the written NIfTI file with shape (4, 4, 20, 2). The third + axis matches a sensible default for ProcessingContext.nifti_slice_number + so most tests do not need a custom shape. + """ + data = np.zeros((4, 4, 20, 2), dtype=np.float32) + img = nib.Nifti1Image(data, affine=np.eye(4)) + + path = tmp_path / "asl.nii.gz" + nib.save(img, str(path)) + + return path + + +@pytest.fixture +def examples_dir(request: pytest.FixtureRequest) -> Path: + """Return the path to the integration examples directory. + + Args: + request: Pytest fixture request, used to read the --examples-dir flag. + + Returns: + Resolved path to an existing examples directory. + + Raises: + pytest.skip.Exception: If the resolved path does not exist. Resolution + order is (1) the --examples-dir CLI flag, then (2) the committed + tests/integration/examples directory beside this conftest. + """ + cli_value = request.config.getoption("--examples-dir") + + if cli_value: + path = Path(cli_value).expanduser().resolve() + else: + path = Path(__file__).parent / "integration" / "examples" + + if not path.is_dir(): + pytest.skip(f"Examples directory not found: {path}") + + return path + + +# --------------------------------------------------------------------------- +# Factory fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def make_context() -> Callable[..., ProcessingContext]: + """Return a factory for building ProcessingContext with sensible defaults. + + Returns: + A callable that accepts keyword overrides and returns a fully-formed + ProcessingContext. All ten required fields receive defaults; pass + keyword arguments to override any of them. + + Example: + >>> def test_something(make_context): + ... ctx = make_context(m0_type="Separate", errors=["something"]) + ... assert ctx.m0_type == "Separate" + """ + + def _make(**overrides: Any) -> ProcessingContext: + defaults: dict[str, Any] = { + "asl_json_data": [], + "m0_prep_times_collection": [], + "errors": [], + "warnings": [], + "all_absent": True, + "bs_all_off": True, + "m0_type": None, + "global_pattern": None, + "total_acquired_pairs": None, + "nifti_slice_number": 20, + } + + defaults.update(overrides) + + return ProcessingContext(**defaults) + + return _make + + +@pytest.fixture +def make_processor(minimal_nifti_path: Path) -> Callable[..., ASLProcessor]: + """Return a factory for building ASLProcessor without input validation. + + Args: + minimal_nifti_path: Auto-injected fixture providing a real on-disk + NIfTI file. The factory uses it as the default nifti_file unless + the caller overrides it. + + Returns: + A callable that accepts keyword overrides and returns an ASLProcessor. + Useful for testing private methods like _group_files in isolation, + because BaseProcessor.__init__ stores self.data without validating. + + Example: + >>> def test_grouping(make_processor): + ... proc = make_processor(files=["/path/to/asl.json"]) + ... groups = proc._group_files("nifti") + """ + + def _make(**overrides: Any) -> ASLProcessor: + defaults: dict[str, Any] = { + "modality": "asl", + "files": [], + "dcm_files": [], + "nifti_file": str(minimal_nifti_path), + } + + defaults.update(overrides) + + return ASLProcessor(defaults) + + return _make + + +# --------------------------------------------------------------------------- +# Data fixtures (in-memory JSON dicts) +# --------------------------------------------------------------------------- +@pytest.fixture +def minimal_asl_json() -> dict[str, Any]: + """Return a minimal ASL JSON with all major-error fields valid. + + Returns: + Dictionary intended as a starting point for normalization and + validation tests. Spread it into a new dict and override individual + fields to introduce specific missing or invalid values. + """ + return { + "ArterialSpinLabelingType": "PCASL", + "MRAcquisitionType": "3D", + "PulseSequenceType": "GRASE", + "M0Type": "Separate", + "BackgroundSuppression": False, + "PostLabelingDelay": 1.8, + "LabelingDuration": 1.8, + "EchoTime": 0.012, + "RepetitionTimePreparation": 4.0, + "FlipAngle": 90, + "MagneticFieldStrength": 3, + "Manufacturer": "Siemens", + "ManufacturersModelName": "TrioTim", + "AcquisitionVoxelSize": [3, 3, 4], + } diff --git a/package/tests/test_file_grouping.py b/package/tests/test_file_grouping.py new file mode 100644 index 00000000..4f2cef74 --- /dev/null +++ b/package/tests/test_file_grouping.py @@ -0,0 +1,154 @@ +"""Tests for _group_files and FileReader behavior.""" + +import json +from pathlib import Path +from typing import Callable + +import pytest + +from pyaslreport.modalities.asl.processor import ASLProcessor, ProcessingContext + +# ---------- _group_files: NIfTI mode (exact suffix matching) ---------- + + +class TestGroupFilesNiftiMode: + def test_groups_asl_with_tsv_and_m0( + self, make_processor: Callable[..., ASLProcessor], tmp_path: Path + ) -> None: + """A canonical BIDS triple groups together correctly.""" + asl_json = tmp_path / "sub-01_asl.json" + asl_json.write_text(json.dumps({"M0Type": "Separate"})) + tsv = tmp_path / "sub-01_aslcontext.tsv" + tsv.write_text("volume_type\ncontrol\nlabel\n") + m0_json = tmp_path / "sub-01_m0scan.json" + m0_json.write_text(json.dumps({"EchoTime": 0.012})) + + proc = make_processor(files=[str(asl_json), str(tsv), str(m0_json)]) + groups = proc._group_files("nifti") + + assert len(groups) == 1 + g = groups[0] + assert g["asl_json"][0] == "sub-01_asl.json" + assert g["tsv"][0] == "sub-01_aslcontext.tsv" + assert g["m0_json"][0] == "sub-01_m0scan.json" + + def test_two_sessions_produce_two_groups( + self, make_processor: Callable[..., ASLProcessor], tmp_path: Path + ) -> None: + """Two BIDS sessions in the same directory produce two groups.""" + # Build the list explicitly. Do NOT scan tmp_path with iterdir(): the + # make_processor -> minimal_nifti_path fixture also writes asl.nii.gz + # into tmp_path, and _group_files rejects unknown extensions. + files: list[str] = [] + for i in [1, 2]: + asl_json = tmp_path / f"sub-0{i}_asl.json" + asl_json.write_text(json.dumps({"M0Type": "Separate"})) + tsv = tmp_path / f"sub-0{i}_aslcontext.tsv" + tsv.write_text("volume_type\ncontrol\nlabel\n") + files.extend([str(asl_json), str(tsv)]) + proc = make_processor(files=files) + groups = proc._group_files("nifti") + assert len(groups) == 2 + + def test_unsupported_extension_raises( + self, make_processor: Callable[..., ASLProcessor], tmp_path: Path + ) -> None: + """An unsupported extension raises ValueError during grouping.""" + bad = tmp_path / "weird.xml" + bad.write_text("") + proc = make_processor(files=[str(bad)]) + with pytest.raises(ValueError, match="Unsupported file format"): + proc._group_files("nifti") + + +# ---------- _group_files: DICOM mode (substring matching for m0) ---------- + + +class TestGroupFilesDicomMode: + def test_dicom_mode_uses_substring_for_m0( + self, make_processor: Callable[..., ASLProcessor], tmp_path: Path + ) -> None: + """In DICOM mode, any filename containing 'm0' counts as M0.""" + asl_json = tmp_path / "scan_dump.json" + asl_json.write_text(json.dumps({"M0Type": "Separate"})) + m0_json = tmp_path / "scan_m0_dump.json" + m0_json.write_text(json.dumps({"EchoTime": 0.012})) + + proc = make_processor(files=[str(asl_json), str(m0_json)]) + groups = proc._group_files("dicom") + assert len(groups) == 1 + assert groups[0]["asl_json"][0] == "scan_dump.json" + assert groups[0]["m0_json"][0] == "scan_m0_dump.json" + + +# ---------- _validate_tsv_data: missing TSV behavior ---------- + + +class TestMissingTSV: + def test_missing_tsv_in_nifti_mode_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """Missing aslcontext.tsv in NIfTI mode produces a missing-file error.""" + proc = make_processor() + ctx = make_context() + group = { + "asl_json": ("asl.json", {"M0Type": "Absent"}), + "m0_json": None, + "tsv": None, + } + proc._validate_tsv_data(group, ctx, "asl.json", group["asl_json"][1], "nifti") + assert any("aslcontext.tsv" in e and "missing" in e for e in ctx.errors) + + def test_missing_tsv_in_dicom_mode_falls_through_to_dicom_repetitions( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """Missing TSV in DICOM mode delegates to _analyze_dicom_repetitions.""" + proc = make_processor() + ctx = make_context() + asl_data = {"lRepetitions": 10} + group = {"asl_json": ("asl.json", asl_data), "m0_json": None, "tsv": None} + proc._validate_tsv_data(group, ctx, "asl.json", asl_data, "dicom") + # _analyze_dicom_repetitions sets total_acquired_pairs from lRepetitions/2 + assert ctx.total_acquired_pairs == 5 + # No TSV-missing error in DICOM mode + assert not any("aslcontext.tsv" in e for e in ctx.errors) + + +# ---------- FileReader: TSV header enforcement ---------- + + +class TestFileReaderTSVHeader: + def test_valid_header_returns_data(self, tmp_path: Path) -> None: + """A 'volume_type' header with rows returns the rows as a list.""" + from pyaslreport.io.readers.file_reader import FileReader + + f = tmp_path / "valid.tsv" + f.write_text("volume_type\ncontrol\nlabel\n") + result = FileReader.read(str(f)) + assert result == ["control", "label"] + + def test_invalid_header_raises(self, tmp_path: Path) -> None: + """A header that isn't exactly 'volume_type' raises RuntimeError. + + NOTE: FileReader.read re-wraps the inner error as + 'Error reading file: Invalid TSV header, ...'. The substring match below + still matches; do NOT anchor this regex with '^'. + """ + from pyaslreport.io.readers.file_reader import FileReader + + f = tmp_path / "bad.tsv" + f.write_text("volume_types\ncontrol\nlabel\n") # plural, wrong + with pytest.raises(RuntimeError, match="Invalid TSV header"): + FileReader.read(str(f)) + + def test_empty_file_returns_none(self, tmp_path: Path) -> None: + """A truly empty TSV returns None rather than raising.""" + from pyaslreport.io.readers.file_reader import FileReader + + f = tmp_path / "empty.tsv" + f.write_text("") + assert FileReader.read(str(f)) is None diff --git a/package/tests/test_m0_tsv_validation.py b/package/tests/test_m0_tsv_validation.py new file mode 100644 index 00000000..3cf1ed85 --- /dev/null +++ b/package/tests/test_m0_tsv_validation.py @@ -0,0 +1,188 @@ +"""Tests for M0 contradiction paths, TSV validation, and BS warnings.""" + +from typing import Callable + +from pyaslreport.modalities.asl.processor import ASLProcessor, ProcessingContext + +# ---------- _validate_m0_data ---------- + + +class TestValidateM0Data: + def test_separate_with_no_m0_file_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Separate but m0_json missing -> error appended.""" + proc = make_processor() + ctx = make_context(m0_type="Separate") + group = { + "asl_json": ("asl.json", {"M0Type": "Separate"}), + "m0_json": None, + "tsv": None, + } + proc._validate_m0_data(group, ctx, "asl.json", group["asl_json"][1]) + assert any("Separate" in e and "not provided" in e for e in ctx.errors) + + def test_absent_with_m0_file_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Absent but m0_json present -> error appended.""" + proc = make_processor() + ctx = make_context(m0_type="Absent") + m0_data = {"EchoTime": 0.012} + group = { + "asl_json": ("asl.json", {"M0Type": "Absent"}), + "m0_json": ("m0.json", m0_data), + "tsv": None, + } + proc._validate_m0_data(group, ctx, "asl.json", group["asl_json"][1]) + assert any("Absent" in e and "is present" in e for e in ctx.errors) + + def test_included_with_separate_m0_file_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Included but separate m0_json provided -> error.""" + proc = make_processor() + ctx = make_context(m0_type="Included") + m0_data = {"EchoTime": 0.012} + group = { + "asl_json": ("asl.json", {"M0Type": "Included"}), + "m0_json": ("m0.json", m0_data), + "tsv": None, + } + proc._validate_m0_data(group, ctx, "asl.json", group["asl_json"][1]) + assert any("Included" in e for e in ctx.errors) + + def test_separate_with_m0_file_no_error( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Separate with m0_json present -> no contradiction. + + Both ASL and M0 dicts agree on the five compare_params fields to avoid + spurious errors from that helper. + """ + proc = make_processor() + ctx = make_context(m0_type="Separate") + m0_data = { + "EchoTime": 0.012, + "FlipAngle": 90, + "MagneticFieldStrength": 3, + "MRAcquisitionType": "3D", + "PulseSequenceType": "GRASE", + } + asl_data = dict(m0_data, M0Type="Separate") + group = { + "asl_json": ("asl.json", asl_data), + "m0_json": ("m0.json", m0_data), + "tsv": None, + } + proc._validate_m0_data(group, ctx, "asl.json", asl_data) + m0_type_errors = [ + e for e in ctx.errors if "M0 type" in e or "specified as" in e + ] + assert m0_type_errors == [] + + +# ---------- TSV: _analyze_tsv_volume_types and _validate_m0scan_consistency ---------- + + +class TestTSVValidation: + def test_absent_with_m0scan_in_tsv_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Absent but TSV contains 'm0scan' -> error.""" + proc = make_processor() + ctx = make_context(m0_type="Absent") + asl_data = {"M0Type": "Absent"} + tsv_data = ["m0scan", "control", "label"] + proc._analyze_tsv_volume_types( + tsv_data, ctx, "asl.json", asl_data, "context.tsv" + ) + assert any("Absent" in e and "m0scan" in e for e in ctx.errors) + + def test_separate_with_m0scan_in_tsv_errors( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """M0Type=Separate but TSV contains 'm0scan' -> error.""" + proc = make_processor() + ctx = make_context(m0_type="Separate") + asl_data = {"M0Type": "Separate"} + tsv_data = ["m0scan", "control", "label"] + proc._analyze_tsv_volume_types( + tsv_data, ctx, "asl.json", asl_data, "context.tsv" + ) + assert any("Separate" in e and "m0scan" in e for e in ctx.errors) + + def test_total_acquired_pairs_set( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """A clean TSV populates TotalAcquiredPairs.""" + proc = make_processor() + ctx = make_context(m0_type="Included") + asl_data = { + "M0Type": "Included", + "RepetitionTimePreparation": 4.0, + "BackgroundSuppression": False, + } + tsv_data = ["m0scan", "control", "label", "control", "label"] + proc._analyze_tsv_volume_types( + tsv_data, ctx, "asl.json", asl_data, "context.tsv" + ) + assert asl_data["TotalAcquiredPairs"] == 2 # two control-label pairs + + +# ---------- _handle_no_m0scan_warnings ---------- + + +class TestBackgroundSuppressionWarnings: + def test_bs_off_no_warning( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """BackgroundSuppression off -> no warnings.""" + proc = make_processor() + ctx = make_context() + asl_data = {"BackgroundSuppression": False} + proc._handle_no_m0scan_warnings(ctx, "asl.json", asl_data) + assert ctx.warnings == [] + + def test_bs_on_with_pulse_time_warns_about_efficiency( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """BS on with pulse times -> efficiency warning.""" + proc = make_processor() + ctx = make_context() + asl_data = { + "BackgroundSuppression": True, + "BackgroundSuppressionPulseTime": [0.15, 0.5], + } + proc._handle_no_m0scan_warnings(ctx, "asl.json", asl_data) + assert any("BS-pulse efficiency" in w for w in ctx.warnings) + + def test_bs_on_no_pulse_time_warns_about_relative_quantification( + self, + make_processor: Callable[..., ASLProcessor], + make_context: Callable[..., ProcessingContext], + ) -> None: + """BS on without pulse times -> relative-quantification warning.""" + proc = make_processor() + ctx = make_context() + asl_data = {"BackgroundSuppression": True} + proc._handle_no_m0scan_warnings(ctx, "asl.json", asl_data) + assert any("relative quantification" in w for w in ctx.warnings) diff --git a/package/tests/test_normalization.py b/package/tests/test_normalization.py new file mode 100644 index 00000000..cb9a5f25 --- /dev/null +++ b/package/tests/test_normalization.py @@ -0,0 +1,156 @@ +"""Tests for normalization logic in ASLProcessor.""" + +from typing import Callable + +import pytest + +from pyaslreport.modalities.asl.processor import ASLProcessor +from pyaslreport.utils.math_utils import MathUtils +from pyaslreport.utils.unit_conversion_utils import UnitConverterUtils + +# ---------- _rename_fields ---------- + + +@pytest.mark.parametrize( + "old_key,new_key,value", + [ + ("RepetitionTime", "RepetitionTimePreparation", 4.5), + ("InversionTime", "PostLabelingDelay", 1.8), + ("BolusDuration", "BolusCutOffDelayTime", 0.7), + ("InitialPostLabelDelay", "PostLabelingDelay", 1.8), + ], +) +def test_rename_fields_renames_and_deletes_legacy( + make_processor: Callable[..., ASLProcessor], + old_key: str, + new_key: str, + value: float, +) -> None: + """Each of the four mappings copies old->new and deletes the old key.""" + proc = make_processor() + session = {old_key: value} + proc._rename_fields(session) + assert session[new_key] == value + assert old_key not in session + + +def test_rename_fields_numrfblocks_derives_labeling_duration( + make_processor: Callable[..., ASLProcessor], +) -> None: + """NumRFBlocks=100 derives LabelingDuration=1.84 (pins the numeric contract).""" + proc = make_processor() + session = {"NumRFBlocks": 100} + proc._rename_fields(session) + assert session["LabelingDuration"] == pytest.approx(1.84) + + +def test_rename_fields_numrfblocks_retains_source( + make_processor: Callable[..., ASLProcessor], +) -> None: + """NumRFBlocks is retained after deriving LabelingDuration for provenance.""" + proc = make_processor() + session = {"NumRFBlocks": 100} + proc._rename_fields(session) + assert "NumRFBlocks" in session + + +def test_rename_fields_no_legacy_keys_is_noop( + make_processor: Callable[..., ASLProcessor], +) -> None: + """A modern session with no legacy keys is unchanged.""" + proc = make_processor() + session = {"PostLabelingDelay": 1.8, "RepetitionTimePreparation": 4.0} + original = dict(session) + proc._rename_fields(session) + assert session == original + + +# ---------- _convert_units_to_milliseconds ---------- + +TIME_FIELDS = [ + "EchoTime", + "RepetitionTimePreparation", + "LabelingDuration", + "BolusCutOffDelayTime", + "BackgroundSuppressionPulseTime", + "PostLabelingDelay", +] + + +@pytest.mark.parametrize("field", TIME_FIELDS) +def test_convert_scalar( + make_processor: Callable[..., ASLProcessor], field: str +) -> None: + """Each time field scalar is multiplied by 1000.""" + proc = make_processor() + session = {field: 1.5} + proc._convert_units_to_milliseconds(session) + assert session[field] == 1500 + + +@pytest.mark.parametrize("field", TIME_FIELDS) +def test_convert_list(make_processor: Callable[..., ASLProcessor], field: str) -> None: + """Each time field list maps element-wise.""" + proc = make_processor() + session = {field: [1.0, 2.0, 3.0]} + proc._convert_units_to_milliseconds(session) + assert session[field] == [1000, 2000, 3000] + + +def test_convert_leaves_non_time_fields( + make_processor: Callable[..., ASLProcessor], +) -> None: + """Non-time fields untouched; time field converted.""" + proc = make_processor() + session = {"FlipAngle": 90, "EchoTime": 0.012} + proc._convert_units_to_milliseconds(session) + assert session["FlipAngle"] == 90 + assert session["EchoTime"] == 12 + + +# ---------- MathUtils.round_if_close ---------- + + +@pytest.mark.parametrize( + "input_val,expected", + [ + (2000.0, 2000), + (2000.0000001, 2000), + (2000.5, 2000.5), + (1999.9999999, 2000), + (2000.123456, 2000.123), + (0.0, 0), + (-2000.0, -2000), + ], +) +def test_round_if_close(input_val: float, expected: int | float) -> None: + """Snaps to int within 1e-6 of an integer; else rounds to 3 decimals.""" + result = MathUtils.round_if_close(input_val) + assert result == expected + if isinstance(expected, int): + assert isinstance(result, int) + + +# ---------- UnitConverterUtils.convert_to_milliseconds ---------- + + +def test_convert_to_ms_scalar() -> None: + """A scalar in seconds becomes milliseconds (int when close to integer).""" + assert UnitConverterUtils.convert_to_milliseconds(2.0) == 2000 + + +def test_convert_to_ms_list() -> None: + """A list maps element-wise.""" + assert UnitConverterUtils.convert_to_milliseconds([1.0, 2.0]) == [1000, 2000] + + +def test_convert_to_ms_rejects_string() -> None: + """A non-numeric scalar raises TypeError.""" + with pytest.raises(TypeError): + UnitConverterUtils.convert_to_milliseconds("not a number") + + +def test_convert_to_ms_rejects_list_with_string() -> None: + """A list containing a non-number raises TypeError.""" + with pytest.raises(TypeError): + UnitConverterUtils.convert_to_milliseconds([1.0, "two"]) diff --git a/package/src/pyaslreport/tests/test_package.py b/package/tests/test_package.py similarity index 63% rename from package/src/pyaslreport/tests/test_package.py rename to package/tests/test_package.py index fda43d33..ae89ca5a 100644 --- a/package/src/pyaslreport/tests/test_package.py +++ b/package/tests/test_package.py @@ -1,38 +1,50 @@ +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock -from .main import get_bids_metadata + +from pyaslreport.main import get_bids_metadata # filepath: /home/ibrahim/MyPc/Projects/GSoC/ASL-Parameter-Generator/package/src/pyaslreport/test_main.py + def test_get_bids_metadata_success(): data = {"modality": "asl", "dicom_dir": "/fake/dir"} fake_header = MagicMock() fake_sequence = MagicMock() fake_sequence.extract_bids_metadata.return_value = ("meta", "context") - with patch("pyaslreport.main.get_dicom_header", return_value=fake_header), \ - patch("pyaslreport.main.get_sequence", return_value=fake_sequence): + with ( + patch("pyaslreport.main.get_dicom_header", return_value=fake_header), + patch("pyaslreport.main.get_sequence", return_value=fake_sequence), + ): result = get_bids_metadata(data) assert result == ("meta", "context") fake_sequence.extract_bids_metadata.assert_called_once() + def test_get_bids_metadata_no_dicom_dir(): data = {"modality": "asl"} with pytest.raises(TypeError): get_bids_metadata(data) + def test_get_bids_metadata_no_sequence(): data = {"modality": "asl", "dicom_dir": "/fake/dir"} fake_header = MagicMock() - with patch("pyaslreport.main.get_dicom_header", return_value=fake_header), \ - patch("pyaslreport.main.get_sequence", return_value=None): + with ( + patch("pyaslreport.main.get_dicom_header", return_value=fake_header), + patch("pyaslreport.main.get_sequence", return_value=None), + ): with pytest.raises(ValueError) as exc: get_bids_metadata(data) assert "No matching sequence found" in str(exc.value) + def test_get_bids_metadata_invalid_modality(): data = {"modality": None, "dicom_dir": "/fake/dir"} fake_header = MagicMock() - with patch("pyaslreport.main.get_dicom_header", return_value=fake_header), \ - patch("pyaslreport.main.get_sequence", return_value=None): + with ( + patch("pyaslreport.main.get_dicom_header", return_value=fake_header), + patch("pyaslreport.main.get_sequence", return_value=None), + ): with pytest.raises(ValueError): - get_bids_metadata(data) \ No newline at end of file + get_bids_metadata(data) diff --git a/package/tests/test_validators.py b/package/tests/test_validators.py new file mode 100644 index 00000000..953c4b48 --- /dev/null +++ b/package/tests/test_validators.py @@ -0,0 +1,210 @@ +"""Tests for the six validator classes plus schema loading.""" + +from typing import Any + +from pyaslreport.core.config import config +from pyaslreport.modalities.asl.validators import ( + BooleanValidator, + ConsistencyValidator, + NumberArrayValidator, + NumberOrNumberArrayValidator, + NumberValidator, + StringValidator, +) + +MAJOR_SLOT, ERROR_SLOT, WARNING_SLOT = 0, 2, 4 + + +def fired_major(r: tuple[Any, ...]) -> bool: + """True if the major-error slot is populated.""" + return r[MAJOR_SLOT] is not None + + +def fired_error(r: tuple[Any, ...]) -> bool: + """True if the (non-major) error slot is populated.""" + return r[ERROR_SLOT] is not None + + +def fired_warning(r: tuple[Any, ...]) -> bool: + """True if the warning slot is populated.""" + return r[WARNING_SLOT] is not None + + +def all_clear(r: tuple[Any, ...]) -> bool: + """True if no slot is populated.""" + return all(x is None for x in r) + + +class TestNumberValidator: + def test_within_bounds(self) -> None: + assert all_clear(NumberValidator(min_error=0, max_error=10).validate(5)) + + def test_min_error_is_strict(self) -> None: + v = NumberValidator(min_error=0) + assert v.validate(0)[ERROR_SLOT] == "Value must be > 0" + assert all_clear(v.validate(0.0001)) + + def test_above_max_error(self) -> None: + assert fired_error(NumberValidator(max_error=10).validate(20)) + + def test_min_error_include_is_inclusive(self) -> None: + v = NumberValidator(min_error_include=0) + assert all_clear(v.validate(0)) + assert fired_error(v.validate(-1)) + + def test_max_error_include_boundary(self) -> None: + v = NumberValidator(max_error_include=360) + assert all_clear(v.validate(360)) + assert v.validate(360.001)[ERROR_SLOT] == "Value must be <= 360" + + def test_warning_threshold(self) -> None: + assert fired_warning(NumberValidator(min_warning=0).validate(-1)) + + def test_enforce_integer_alone(self) -> None: + v = NumberValidator(enforce_integer=True) + assert all_clear(v.validate(5)) + assert fired_error(v.validate(5.5)) + + def test_enforce_integer_rule_precedes_range(self) -> None: + """Integer rule is added first, so it fires before the range check.""" + v = NumberValidator(min_error=0, enforce_integer=True) + assert v.validate(2.5)[ERROR_SLOT] == "Value must be an integer" + assert v.validate(-1)[ERROR_SLOT] == "Value must be > 0" + assert all_clear(v.validate(3)) + + +class TestStringValidator: + def test_allowed_passes(self) -> None: + assert all_clear( + StringValidator(allowed_values=["PCASL", "PASL"]).validate("PCASL") + ) + + def test_case_insensitive(self) -> None: + assert all_clear(StringValidator(allowed_values=["PCASL"]).validate("pcasl")) + + def test_disallowed_routes_to_error(self) -> None: + r = StringValidator(allowed_values=["PCASL"]).validate("XYZ") + assert fired_error(r) and not fired_major(r) + + def test_major_flag_routes_to_major(self) -> None: + r = StringValidator(allowed_values=["PCASL"], major_error=True).validate("XYZ") + assert fired_major(r) and not fired_error(r) + + def test_no_allowed_values_accepts_anything(self) -> None: + assert all_clear(StringValidator().validate("anything")) + + +class TestBooleanValidator: + def test_true_false_pass(self) -> None: + assert all_clear(BooleanValidator().validate(True)) + assert all_clear(BooleanValidator().validate(False)) + + def test_string_errors(self) -> None: + assert fired_error(BooleanValidator().validate("true")) + + def test_int_errors(self) -> None: + """isinstance(1, bool) is False, so 1 is rejected.""" + assert fired_error(BooleanValidator().validate(1)) + + +class TestNumberArrayValidator: + def test_exact_size(self) -> None: + v = NumberArrayValidator(size_error=3) + assert all_clear(v.validate([1, 2, 3])) + assert fired_error(v.validate([1, 2])) + assert fired_error(v.validate([1, 2, "x"])) + + def test_min_error_skips_non_numbers(self) -> None: + """Range rule's guarded comprehension ignores non-numeric elements.""" + v = NumberArrayValidator(min_error=0) + assert fired_error(v.validate([1, -1, 2])) + assert all_clear(v.validate([1, 2, "x"])) + + def test_ascending_allows_equal_neighbors(self) -> None: + v = NumberArrayValidator(check_ascending=True) + assert all_clear(v.validate([1, 1, 2])) + assert fired_error(v.validate([3, 1, 2])) + + +class TestNumberOrNumberArrayValidator: + def test_scalar_dispatch(self) -> None: + v = NumberOrNumberArrayValidator(min_error=0) + assert fired_error(v.validate(-1)) + assert all_clear(v.validate(5)) + + def test_array_dispatch(self) -> None: + v = NumberOrNumberArrayValidator(min_error=0) + assert fired_error(v.validate([1, -1, 2])) + assert all_clear(v.validate([1, 2, 3])) + + def test_wrong_type_is_reported_as_major_error(self) -> None: + """A non-numeric, non-list value is reported as a major error by design.""" + r = NumberOrNumberArrayValidator().validate("x") + assert fired_major(r) and not fired_error(r) + + +class TestConsistencyValidator: + def test_string_same_passes(self) -> None: + v = ConsistencyValidator(validation_type="string") + assert all_clear(v.validate([("PCASL", "a"), ("PCASL", "b")])) + + def test_string_mismatch_error_tier(self) -> None: + v = ConsistencyValidator(validation_type="string", is_major=False) + assert fired_error(v.validate([("PCASL", "a"), ("PASL", "b")])) + + def test_string_mismatch_major_tier(self) -> None: + v = ConsistencyValidator(validation_type="string", is_major=True) + assert fired_major(v.validate([("PCASL", "a"), ("PASL", "b")])) + + def test_float_within_warning(self) -> None: + v = ConsistencyValidator( + "floatOrArray", error_variation=10, warning_variation=0.1 + ) + assert all_clear(v.validate([(2000, "a"), (2000.05, "b")])) + + def test_float_crosses_warning(self) -> None: + v = ConsistencyValidator( + "floatOrArray", error_variation=10, warning_variation=0.1 + ) + assert fired_warning(v.validate([(2000, "a"), (2000.5, "b")])) + + def test_float_crosses_error(self) -> None: + v = ConsistencyValidator( + "floatOrArray", error_variation=10, warning_variation=0.1 + ) + assert fired_error(v.validate([(2000, "a"), (2020, "b")])) + + def test_boolean_same_passes(self) -> None: + assert all_clear( + ConsistencyValidator("boolean").validate([(True, "a"), (True, "b")]) + ) + + def test_boolean_mixed_errors(self) -> None: + assert fired_error( + ConsistencyValidator("boolean").validate([(True, "a"), (False, "b")]) + ) + + +def test_all_expected_schemas_loaded() -> None: + """config['schemas'] contains every shipped schema.""" + expected = { + "major_error_schema", + "required_validator_schema", + "required_condition_schema", + "recommended_validator_schema", + "recommended_condition_schema", + "consistency_schema", + } + assert expected.issubset(config["schemas"].keys()) + + +def test_major_error_schema_fields() -> None: + """The four documented major-error fields are present.""" + schema = config["schemas"]["major_error_schema"] + for field in ( + "PLDType", + "ArterialSpinLabelingType", + "MRAcquisitionType", + "PulseSequenceType", + ): + assert field in schema