diff --git a/.gitignore b/.gitignore index bfb34fe..b34d1dd 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,6 @@ Thumbs.db # IDEs .vscode/ .cursor/ + +# Other +notes.* diff --git a/pyproject.toml b/pyproject.toml index f1e8460..7ca1b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "ophyd", "ophyd-async[ca] >=0.19", "h5py", + "xrayutilities>=1.7.12,<2" ] [project.optional-dependencies] @@ -45,6 +46,7 @@ test = [ "tiled[minimal-client]", "tiled[minimal-server]", "ophyd >=v1.10.6", + "pytest-watcher", ] dev = [ "caproto[standard] >=0.4.2rc1,!=1.2.0", @@ -92,15 +94,12 @@ dev-dependencies = [ [tool.pytest.ini_options] minversion = "6.0" -addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +# TODO - fix the eiger_async module and tests +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--ignore=tests/test_eiger_async.py"] xfail_strict = true -filterwarnings = [ - "ignore", -] +filterwarnings = "ignore" log_cli_level = "INFO" -testpaths = [ - "tests", -] +testpaths = "tests" [tool.coverage] run.source = ["cditools"] @@ -164,11 +163,17 @@ isort.required-imports = ["from __future__ import annotations"] [tool.pixi.project] channels = ["conda-forge"] -platforms = ["linux-64"] +platforms = ["linux-64", "osx-arm64"] [tool.pixi.pypi-dependencies] cditools = { path = ".", editable = true } +[tool.pixi.dependencies] +xrayutilities = ">=1.7.12,<2" + +[tool.pixi.feature.dev.tasks] +format = "ruff format ." + [tool.pixi.environments] default = { solve-group = "default" } dev = { features = ["dev"], solve-group = "default" } diff --git a/src/cditools/attenuator.py b/src/cditools/attenuator.py new file mode 100644 index 0000000..f466ae7 --- /dev/null +++ b/src/cditools/attenuator.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import asyncio +import math +from collections import OrderedDict +from dataclasses import dataclass + +import numpy as np +import xrayutilities as xu +from event_model import DataKey # type: ignore[import-untyped] +from ophyd_async.core import ( + AsyncMovable, + AsyncStatus, + DeviceVector, + StandardReadable, + StrictEnum, +) +from ophyd_async.epics.core import EpicsDevice, epics_signal_r, epics_signal_rw + +from cditools.motors import Energy + + +@dataclass +class AttenuatorCombination: + transmission: float + attenuators: list[int] + + @property + def attenuation(self): + return 1 - self.transmission + + +class AttenuatorStatusEnum(StrictEnum): + LOW = "Low" # off / not obstructing + HIGH = "High" # on / obstructing + + +class Attenuator(EpicsDevice, AsyncMovable[AttenuatorStatusEnum]): + def __init__(self, prefix: str, num: int, material: str, thickness: float): + """ + Parameters + ---------- + prefix : str + The common prefix for the attenuator bank + num : int + An integer denoting which attenuator within the bank this is + thickness : float + The thickness of the attenuator in microns + + Attributes + ---------- + position : SignalRW[AttenuatorStatusEnum] + The read / write PV to open and close the attenuator and get + the current state of the attenuator + mode : SignalRW[bool] + in_status : SignalR[AttenuatorStatusEnum] + """ + self.prefix = prefix + self.num = num + self.filter_material = getattr(xu.materials, material) + self.thickness = thickness # microns + + self.position = epics_signal_rw( + AttenuatorStatusEnum, + f"{self.prefix}:DO{self.num}-Sts", + write_pv=f"{self.prefix}:DO{self.num}-Cmd", + ) + self.mode = epics_signal_rw(bool, f"{self.prefix}:DIO{self.num}-Mode") + self.in_status = epics_signal_r( + AttenuatorStatusEnum, f"{self.prefix}:DI{self.num}-Sts" + ) + + super().__init__(prefix=self.prefix) + + def __repr__(self): + return f"{self.thickness!s} microns, {self.filter_material}" + + @property + def name(self): + return f"attenuator_{self.num}" + + @AsyncStatus.wrap + async def set(self, value: AttenuatorStatusEnum): + await self.position.set(value) + + async def open(self): + """Open means open to allowing the beam to pass unobstructed""" + await self.position.set(AttenuatorStatusEnum.LOW) + + async def close(self): + """Closed means obstructing the beam""" + await self.position.set(AttenuatorStatusEnum.HIGH) + + def attenuation(self, photon_energy: float, egu: str = "KeV"): + """Attenuation is the fraction of the beam removed""" + return 1 - self.transmission(photon_energy, egu=egu) + + def transmission(self, photon_energy: float, egu: str = "KeV"): + """Transmission is the fraction of beam remaining""" + abs_len = self._absorption_length(photon_energy, egu=egu) + return np.exp(-self.thickness / abs_len) + + def _absorption_length(self, photon_energy: float, egu: str = "KeV") -> float: + """ + Calculates L, the characteristic absorption length of this material, + at this beam energy. + + Parameters + ---------- + photon energy : float + The beam energy + egu : {'KeV', 'eV'} + The engineering units of the beam energy + + Returns + ------- + float + The characteristic absorption length of the filter material (microns) + """ + if egu == "KeV": + photon_energy = photon_energy * 1e3 + elif egu != "eV": + msg = f"Photon energy units must be eV or KeV (not {egu=})" + raise RuntimeError(msg) + return self.filter_material.absorption_length(photon_energy) # type: ignore[reportArgumentType] + + +class AttenuatorBank(StandardReadable, EpicsDevice, AsyncMovable[float]): + """ + The ioc for the iologik1 lives on xf09id1-inst-ioc1.nsls2.bnl.gov + """ + + def __init__( + self, prefix: str, atten_configs: list[tuple[str, float]], energy: Energy + ): + self.prefix = prefix + self.energy = energy + + with self.add_children_as_readables(): + self.attenuators = DeviceVector( + { + i: Attenuator(self.prefix, i, material, thickness) + for i, (material, thickness) in enumerate(atten_configs, start=1) + } + ) + super().__init__(prefix=self.prefix) + + def get_photon_energy(self): + return self.energy.energy.readback.get() + + def get_egu(self): + return self.energy.egu + + async def read(self): # type: ignore[reportUnknownParameterType] + """ + Polls the bluesky energy object for the current beam energy, and + returns that energy, each filter position, each transmission, and + the total transmission. + """ + status = OrderedDict() + active_attens = [] + energy = self.get_photon_energy() + egu = self.get_egu() + positions = await asyncio.gather( + *(a.position.get_value() for _, a in self.attenuators.items()) + ) + for i, pos in zip(self.attenuators, positions): + atten = self.attenuators[i] + is_active = pos == AttenuatorStatusEnum.HIGH + if is_active: + active_attens.append(atten) + transmission = atten.transmission(energy, egu) if is_active else 0 + status[atten.name] = {"active": is_active, "transmission": transmission} + status["photon_energy"] = energy + status["egu"] = egu + status["total_transmission"] = self._calculate_total_transmission( + energy, *active_attens + ) + return status + + async def describe(self) -> OrderedDict[str, DataKey]: + """Describe the structure of values returned by read().""" + + description = OrderedDict() + + for atten in self.attenuators.values(): + description[atten.name] = DataKey( + source=atten.position.source, + dtype="string", + shape=[], + ) + energy_source = getattr( + self.energy.energy.readback, + "source", + f"ca://{self.prefix}:photon_energy", + ) + description["photon_energy"] = DataKey( + source=energy_source, + dtype="number", + shape=[], + ) + description["egu"] = DataKey( + source=f"ca://{self.prefix}:egu", + dtype="string", + shape=[], + ) + description["total_transmission"] = DataKey( + source=f"ca://{self.prefix}:total_transmission", + dtype="number", + shape=[], + ) + + return description + + @AsyncStatus.wrap + async def set(self, value: float): + """Set the transmission for the attenuator bank""" + attenuation_combination = self.find_closest_transmission( + self.get_photon_energy(), value + ) + coros = [] + for ( + num, + atten, + ) in self.attenuators.items(): + if num in attenuation_combination.attenuators: + coros.append(atten.close()) + else: + coros.append(atten.open()) + await asyncio.gather(*coros) + + def find_closest_transmission( + self, photon_energy: float, target_transmission: float + ) -> AttenuatorCombination: + available_attenuations = self._calculate_available_transmissions(photon_energy) + best_idx = np.argmin( + [abs(target_transmission - _.transmission) for _ in available_attenuations] + ) + return available_attenuations[best_idx] + + def _calculate_available_transmissions( + self, photon_energy: float + ) -> list[AttenuatorCombination]: + """ + Calculates all possible transmissions for the attenuator bank, using + the powerset of the available attenuators. The result is not sorted, + as it is more efficient to scan linearly each time for the closest + match. + """ + available_transmissions = [] + for combination in self._powerset(): + attens = [self.attenuators[a] for a in self.attenuators if a in combination] + total_transmission = self._calculate_total_transmission( + photon_energy, *attens + ) + available_transmissions.append( + AttenuatorCombination(total_transmission, combination) + ) + return available_transmissions + + def _calculate_total_transmission( + self, photon_energy: float, *attenuators: Attenuator + ) -> float: + transmissions = [ + a.transmission(photon_energy, self.get_egu()) for a in attenuators + ] + return round(float(math.prod(transmissions)), 3) + + def _powerset(self) -> list[list[int]]: + """ + This is a famously O(n*2^n) problem. + """ + powerset = [] + for i in range(1 << len(self.attenuators)): + combination = [] + for j in range(len(self.attenuators)): + if i & (1 << j): + combination.append(j + 1) # +1 because attenuators are 1-indexed + powerset.append(combination) + return powerset diff --git a/tests/test_attenuator.py b/tests/test_attenuator.py new file mode 100644 index 0000000..63ea0ca --- /dev/null +++ b/tests/test_attenuator.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +import pytest_asyncio +from ophyd_async.core import get_mock_put, init_devices, set_mock_value +from ophyd_async.testing import assert_value + +from cditools.attenuator import ( + Attenuator, + AttenuatorBank, + AttenuatorCombination, + AttenuatorStatusEnum, +) +from cditools.motors import Energy + +pytest_plugins = ("pytest_asyncio",) +photon_energy = 8.6 # KeV +prefix = "test-prefix" +attenuator_configs = [("Al", 16.0), ("Al", 24.0), ("Al", 66.0), ("Al", 124.0)] + +# These are the attenuations at photon_energy = 8.6 KeV +TEST_ATTENUATIONS = [ + AttenuatorCombination(transmission=0.084, attenuators=[1, 2, 3, 4]), + AttenuatorCombination(transmission=0.1, attenuators=[2, 3, 4]), + AttenuatorCombination(transmission=0.109, attenuators=[1, 3, 4]), + AttenuatorCombination(transmission=0.129, attenuators=[3, 4]), + AttenuatorCombination(transmission=0.171, attenuators=[1, 2, 4]), + AttenuatorCombination(transmission=0.203, attenuators=[2, 4]), + AttenuatorCombination(transmission=0.222, attenuators=[1, 4]), + AttenuatorCombination(transmission=0.263, attenuators=[4]), + AttenuatorCombination(transmission=0.32, attenuators=[1, 2, 3]), + AttenuatorCombination(transmission=0.38, attenuators=[2, 3]), + AttenuatorCombination(transmission=0.414, attenuators=[1, 3]), + AttenuatorCombination(transmission=0.492, attenuators=[3]), + AttenuatorCombination(transmission=0.65, attenuators=[1, 2]), + AttenuatorCombination(transmission=0.772, attenuators=[2]), + AttenuatorCombination(transmission=0.842, attenuators=[1]), + AttenuatorCombination(transmission=1.0, attenuators=[]), +] + + +@pytest_asyncio.fixture +async def mock_attenuator_bank(): + async with init_devices(mock=True): + mock_energy = MagicMock(spec=Energy) + mock_energy.energy.readback.get.return_value = photon_energy + mock_energy.egu = "KeV" + mock_attenuator_bank = AttenuatorBank(prefix, attenuator_configs, mock_energy) + yield mock_attenuator_bank + + +@pytest_asyncio.fixture +async def mock_attenuator(mock_attenuator_bank: AttenuatorBank): + async with init_devices(mock=True): + mock_attenuator = Attenuator(mock_attenuator_bank.prefix, 1, "Al", 16) + yield mock_attenuator + + +class TestAttenuator: + @pytest.mark.asyncio + async def test_open(self, mock_attenuator: Attenuator): + set_mock_value(mock_attenuator.position, AttenuatorStatusEnum.HIGH) + await mock_attenuator.open() + await assert_value(mock_attenuator.position, AttenuatorStatusEnum.LOW) + + @pytest.mark.asyncio + async def test_close(self, mock_attenuator: Attenuator): + set_mock_value(mock_attenuator.position, AttenuatorStatusEnum.LOW) + await mock_attenuator.close() + await assert_value(mock_attenuator.position, AttenuatorStatusEnum.HIGH) + + def test_transmission_kev(self, mock_attenuator: Attenuator): + assert mock_attenuator.transmission(photon_energy) == pytest.approx( + 0.84, abs=0.01 + ) + + def test_transmission_ev(self, mock_attenuator: Attenuator): + photon_energy = 8600 # eV + assert mock_attenuator.transmission(photon_energy, egu="eV") == pytest.approx( + 0.84, abs=0.01 + ) + + def test_attenuation_kev(self, mock_attenuator: Attenuator): + assert mock_attenuator.attenuation(photon_energy) == pytest.approx( + 0.16, abs=0.01 + ) + + def test_attenuation_ev(self, mock_attenuator: Attenuator): + photon_energy = 8600 # eV + assert mock_attenuator.attenuation(photon_energy, egu="eV") == pytest.approx( + 0.16, abs=0.01 + ) + + +class TestAttenuatorBank: + @pytest.mark.asyncio + async def test_attenuation_bank_creation( + self, mock_attenuator_bank: AttenuatorBank + ): + assert mock_attenuator_bank.energy.energy.readback.get() == 8.6 + # assert mock_attenuator_bank.photon_energy == 8.6 + + second_energy = MagicMock(spec=Energy) + second_energy.energy.readback.get.return_value = 6 + second_bank = AttenuatorBank(prefix, attenuator_configs, second_energy) + assert second_bank.energy.energy.readback.get() == 6 + # assert second_bank.photon_energy == 6 + + @pytest.mark.asyncio + async def test_attenuators_indexed_at_1(self, mock_attenuator_bank: AttenuatorBank): + with pytest.raises(KeyError): + mock_attenuator_bank.attenuators[0] + + atten1 = mock_attenuator_bank.attenuators[1] + assert atten1.num == 1 + assert atten1.thickness == 16 + assert atten1.position.source == "mock+ca://test-prefix:DO1-Sts" + assert atten1.mode.source == "mock+ca://test-prefix:DIO1-Mode" + assert atten1.in_status.source == "mock+ca://test-prefix:DI1-Sts" + assert atten1.name == "attenuator_1" + + atten2 = mock_attenuator_bank.attenuators[2] + assert atten2.num == 2 + assert atten2.thickness == 24 + + atten3 = mock_attenuator_bank.attenuators[3] + assert atten3.num == 3 + assert atten3.thickness == 66 + + atten4 = mock_attenuator_bank.attenuators[4] + assert atten4.num == 4 + assert atten4.thickness == 124 + + @pytest.mark.asyncio + async def test_set_attenuation(self, mock_attenuator_bank: AttenuatorBank): + atten_mock1 = get_mock_put(mock_attenuator_bank.attenuators[1].position) + atten_mock2 = get_mock_put(mock_attenuator_bank.attenuators[2].position) + atten_mock3 = get_mock_put(mock_attenuator_bank.attenuators[3].position) + atten_mock4 = get_mock_put(mock_attenuator_bank.attenuators[4].position) + + combo0 = TEST_ATTENUATIONS[1] # attenuators 2, 3, 4 + await mock_attenuator_bank.set(combo0.transmission) + atten_mock1.assert_called_with(AttenuatorStatusEnum.LOW) + atten_mock2.assert_called_with(AttenuatorStatusEnum.HIGH) + atten_mock3.assert_called_with(AttenuatorStatusEnum.HIGH) + atten_mock4.assert_called_with(AttenuatorStatusEnum.HIGH) + + combo1 = TEST_ATTENUATIONS[-3] # attenuator 2 + await mock_attenuator_bank.set(combo1.transmission) + atten_mock1.assert_called_with(AttenuatorStatusEnum.LOW) + atten_mock2.assert_called_with(AttenuatorStatusEnum.HIGH) + atten_mock3.assert_called_with(AttenuatorStatusEnum.LOW) + atten_mock4.assert_called_with(AttenuatorStatusEnum.LOW) + + @pytest.mark.asyncio + async def test_read(self, mock_attenuator_bank: AttenuatorBank): + mock_attenuator_bank.set(1) + status = await mock_attenuator_bank.read() + assert status == { + "photon_energy": 8.6, + "egu": "KeV", + "total_transmission": 1, + "attenuator_1": {"active": False, "transmission": 0}, + "attenuator_2": {"active": False, "transmission": 0}, + "attenuator_3": {"active": False, "transmission": 0}, + "attenuator_4": {"active": False, "transmission": 0}, + } + + # Test with different energy and attenuations + async with init_devices(mock=True): + second_energy = MagicMock(spec=Energy) + second_energy.energy.readback.get.return_value = 12 + second_energy.egu = "KeV" + second_bank = AttenuatorBank(prefix, attenuator_configs, second_energy) + set_mock_value(second_bank.attenuators[1].position, AttenuatorStatusEnum.LOW) + set_mock_value(second_bank.attenuators[2].position, AttenuatorStatusEnum.HIGH) + set_mock_value(second_bank.attenuators[3].position, AttenuatorStatusEnum.HIGH) + set_mock_value(second_bank.attenuators[4].position, AttenuatorStatusEnum.LOW) + + status = await second_bank.read() + assert status == { + "attenuator_1": {"active": False, "transmission": 0}, + "attenuator_2": { + "active": True, + "transmission": pytest.approx(0.909, rel=0.001), + }, + "attenuator_3": { + "active": True, + "transmission": pytest.approx(0.769, rel=0.001), + }, + "attenuator_4": {"active": False, "transmission": 0}, + "photon_energy": 12, + "egu": "KeV", + "total_transmission": pytest.approx(0.699), + } + + @pytest.mark.asyncio + async def test_describe(self, mock_attenuator_bank: AttenuatorBank): + description = await mock_attenuator_bank.describe() + + expected_keys = { + "attenuator_1", + "attenuator_2", + "attenuator_3", + "attenuator_4", + "photon_energy", + "egu", + "total_transmission", + } + assert set(description.keys()) == expected_keys + + for i in range(1, 5): + assert description[f"attenuator_{i}"] == { + "source": mock_attenuator_bank.attenuators[i].position.source, + "dtype": "string", + "shape": [], + } + + assert description["photon_energy"] == { + "source": mock_attenuator_bank.energy.energy.readback.source, + "dtype": "number", + "shape": [], + } + assert description["egu"] == { + "source": f"ca://{mock_attenuator_bank.prefix}:egu", + "dtype": "string", + "shape": [], + } + assert description["total_transmission"] == { + "source": f"ca://{mock_attenuator_bank.prefix}:total_transmission", + "dtype": "number", + "shape": [], + } + + def test_find_closest_attenuation(self, mock_attenuator_bank: AttenuatorBank): + en = mock_attenuator_bank.energy.energy.readback.get() + nearest = mock_attenuator_bank.find_closest_transmission(en, 0.7) + assert nearest.transmission == 0.65 + + nearest2 = mock_attenuator_bank.find_closest_transmission(en, 0.2) + assert nearest2.transmission == 0.203 + + nearest3 = mock_attenuator_bank.find_closest_transmission(en, 0.02) + assert nearest3.transmission == 0.084 + + nearest4 = mock_attenuator_bank.find_closest_transmission(en, 0.98) + assert nearest4.transmission == 1 + + def test_find_closest_attenuation_with_alt_energies( + self, mock_attenuator_bank: AttenuatorBank + ): + en = mock_attenuator_bank.energy.energy.readback.get() + nearest = mock_attenuator_bank.find_closest_transmission(en, 0.7) + assert nearest == AttenuatorCombination(transmission=0.65, attenuators=[1, 2])