Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/instrumentserver/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,12 @@ def toJson(self) -> Dict[str, Any]:
ret["message"] = self.message.toJson()
elif hasattr(self.message, "attributes"):
ret["message"] = _convert_arbitrary_obj_to_dict(self.message)
# A top-level Enum/IntFlag message (e.g. the return of a parameter get)
# is serialized with its _class_type so the client can reconstruct the
# actual enum instance, not just its value. Must come before the generic
# Iterable check (Flag members are iterable on Python >= 3.11).
elif isinstance(self.message, Enum):
ret["message"] = _convert_enum_to_dict(self.message)
elif not isinstance(self.message, str) and isinstance(self.message, Iterable):
if isinstance(self.message, dict):
message_dict = dict_to_serialized_dict(self.message)
Expand All @@ -723,6 +729,19 @@ def toJson(self) -> Dict[str, Any]:
return ret


def _convert_enum_to_dict(enum_member: Enum) -> Dict[str, Any]:
"""
Converts an Enum/IntFlag member into a serialized dictionary that can be
reconstructed by ``_convert_dict_to_obj``. The enum class must be importable
on the deserializing side, since reconstruction does ``EnumClass(value=...)``.
"""
cls = type(enum_member)
return {
"value": enum_member.value,
"_class_type": f"{cls.__module__}.{cls.__qualname__}",
}


def _convert_arbitrary_obj_to_dict(obj: object) -> Dict[str, Any]:
"""
Converts an arbitrary objects into a dictionary. Assumes that the object contains an attribute called
Expand Down Expand Up @@ -792,6 +811,13 @@ def iterable_to_serialized_dict(
serialized_iterable = dict_to_serialized_dict(dct=item)
converted_iterable.append(serialized_iterable)

# Enum/IntFlag members are treated as scalars. This must come before
# the generic Iterable check: since Python 3.11 a Flag member is
# iterable and a single-bit member iterates to itself, which would
# otherwise recurse forever.
elif isinstance(item, Enum):
converted_iterable.append(str(item.value))

elif not isinstance(item, str) and isinstance(item, Iterable):
serialized_iterable = iterable_to_serialized_dict(iterable=item)
converted_iterable.append(serialized_iterable)
Expand Down Expand Up @@ -834,6 +860,13 @@ def dict_to_serialized_dict(
serialized_iterable = dict_to_serialized_dict(dct=value)
converted_dict[name] = serialized_iterable

# Enum/IntFlag members are treated as scalars. This must come before
# the generic Iterable check: since Python 3.11 a Flag member is
# iterable and a single-bit member iterates to itself, which would
# otherwise recurse forever.
elif isinstance(value, Enum):
converted_dict[name] = str(value.value)

elif not isinstance(value, str) and isinstance(value, Iterable):
serialized_iterable = iterable_to_serialized_dict(iterable=value)
converted_dict[name] = serialized_iterable
Expand Down
47 changes: 47 additions & 0 deletions src/instrumentserver/testing/dummy_instruments/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,27 @@
# No need to mypy check dummy testing instruments.

import time
from enum import IntFlag
from typing import List

import numpy as np
from qcodes import Instrument, validators
from qcodes.math_utils.field_vector import FieldVector


class StatusFlag(IntFlag):
"""An ``IntFlag`` mirroring drivers like the Yokogawa GS200 status byte.

On Python >= 3.11 flag members are iterable and a single-bit member
iterates to itself, which is what triggers the snapshot serialization
recursion bug this instrument is used to test.
"""

EAV = 1 << 2
MAV = 1 << 4
ESB = 1 << 5


class DummyChannel(Instrument):
def __init__(self, name: str, *args, **kwargs):
super().__init__(name, *args, **kwargs)
Expand Down Expand Up @@ -250,3 +264,36 @@ def set_complex_list(self, value):
def generic_function(self):
print("this generic function has been called")
return 3


class DummyInstrumentWithFlags(Instrument):
"""Dummy instrument whose parameters return ``IntFlag`` values.

Mirrors the Yokogawa GS200 status/event register parameters, which store
``IntFlag`` instances as their value. Used to test that station snapshots
containing flag values serialize (and round-trip) correctly.
"""

def __init__(self, name, *args, **kwargs):
super().__init__(name=name, *args, **kwargs)

self._status = StatusFlag.EAV
self._condition = StatusFlag.EAV | StatusFlag.ESB

self.add_parameter(
name="status_byte",
label="Status Byte",
get_cmd=self.get_status,
)

self.add_parameter(
name="condition_register",
label="Condition Register",
get_cmd=self.get_condition,
)

def get_status(self):
return self._status

def get_condition(self):
return self._condition
225 changes: 225 additions & 0 deletions test/pytest/test_enum_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""Tests for serializing ``Enum``/``IntFlag`` valued parameters.

Regression coverage for the infinite-recursion bug triggered when a station
snapshot contains an ``IntFlag`` value (e.g. the Yokogawa GS200 status/event
registers, whose ``get_parser`` returns ``IntFlag`` instances).

Starting with Python 3.11, ``Flag``/``IntFlag`` members are *iterable*, and a
single-bit member iterates to a one-element sequence containing *itself*::

>>> list(StatusByte.EAV)
[<StatusByte.EAV: 4>]
>>> list(StatusByte.EAV)[0] is StatusByte.EAV
True

The serializers (``dict_to_serialized_dict`` / ``iterable_to_serialized_dict``)
recurse into anything that is a non-``str`` ``Iterable``, so before the fix they
recursed into a flag value forever and raised
``RecursionError: maximum recursion depth exceeded`` from ``encode()`` on the
server's send path.

The fix treats any ``Enum`` as a scalar (serializing its ``.value``) instead of
an iterable. These tests assert that scalar behaviour and that no recursion
occurs. They require Python >= 3.11 to reproduce the original failure, but the
expected serialized output is version independent.
"""

import json
from enum import Enum, IntFlag

from instrumentserver.base import encode
from instrumentserver.blueprints import (
ServerResponse,
deserialize_obj,
dict_to_serialized_dict,
iterable_to_serialized_dict,
)
from instrumentserver.testing.dummy_instruments.generic import StatusFlag


class StatusByte(IntFlag):
"""Mirror of the Yokogawa GS200 status byte flag."""

EAV = 1 << 2
MAV = 1 << 4
ESB = 1 << 5


class SourceMode(Enum):
"""A plain (non-flag) enum with string values."""

VOLT = "VOLT"
CURR = "CURR"


def test_intflag_in_dict_serializes_to_value():
"""A single-bit IntFlag value serializes to its underlying int, no recursion."""
result = dict_to_serialized_dict({"status_byte": StatusByte.EAV})
assert result == {"status_byte": str(StatusByte.EAV.value)} # "4"


def test_intflag_in_iterable_serializes_to_value():
result = iterable_to_serialized_dict([StatusByte.EAV])
assert result == [str(StatusByte.EAV.value)] # ["4"]


def test_combined_intflag_serializes_to_combined_value():
"""A multi-bit flag must not be split into its members; keep the int value."""
combined = StatusByte.EAV | StatusByte.MAV # value == 20
result = dict_to_serialized_dict({"status_byte": combined})
assert result == {"status_byte": str(combined.value)} # "20"


def test_plain_enum_serializes_to_value():
result = dict_to_serialized_dict({"source_mode": SourceMode.VOLT})
assert result == {"source_mode": "VOLT"}


def test_nested_snapshot_with_intflag_does_not_recurse():
"""A snapshot-shaped nested dict with IntFlag values serializes fully."""
snapshot = {
"parameters": {
"status_byte": {
"name": "status_byte",
"unit": "",
"value": StatusByte.EAV,
"raw_value": StatusByte.EAV,
},
"voltage": {
"name": "voltage",
"unit": "V",
"value": 0.5,
},
}
}

result = dict_to_serialized_dict(snapshot)

status = result["parameters"]["status_byte"]
assert status["value"] == str(StatusByte.EAV.value)
assert status["raw_value"] == str(StatusByte.EAV.value)
assert result["parameters"]["voltage"]["value"] == "0.5"


def test_encode_server_response_with_intflag_snapshot():
"""End-to-end: the actual failing path ``encode(ServerResponse(...))``.

This is what the server runs in ``send_router``; before the fix it raised
``RecursionError`` on Python 3.11+.
"""
snapshot = {
"parameters": {
"status_byte": {"name": "status_byte", "value": StatusByte.EAV},
"condition_register": {
"name": "condition_register",
"value": StatusByte.EAV | StatusByte.ESB,
},
}
}
response = ServerResponse(message=snapshot)

# Must not raise RecursionError, and must produce valid JSON.
payload = encode(response)
decoded = json.loads(payload)
assert decoded["_class_type"] == "ServerResponse"
assert "status_byte" in decoded["message"]


def test_server_response_message_with_intflag_round_trips():
"""The serialized payload survives a decode back into Python objects."""
snapshot = {"parameters": {"status_byte": {"value": StatusByte.EAV}}}
response = ServerResponse(message=snapshot)

reconstructed = deserialize_obj(json.loads(encode(response)))
assert reconstructed._class_type == "ServerResponse"


def test_server_response_enum_message_reconstructs_enum_type():
"""A top-level Enum message round-trips back into the *actual* enum type.

Unlike a snapshot (where a nested flag is serialized lossily to its value),
a parameter get returns the flag as the message itself. In that case the
client should receive a reconstructed ``StatusFlag`` instance, not a bare
int. ``StatusFlag`` lives in an importable module so deserialization can
rebuild it via its ``_class_type``.
"""
response = ServerResponse(message=StatusFlag.EAV)

reconstructed = deserialize_obj(json.loads(encode(response)))

assert isinstance(reconstructed.message, StatusFlag)
assert reconstructed.message == StatusFlag.EAV


def test_server_response_composite_enum_message_reconstructs_enum_type():
"""A multi-bit flag value also reconstructs to the composite enum member."""
combined = StatusFlag.EAV | StatusFlag.ESB
response = ServerResponse(message=combined)

reconstructed = deserialize_obj(json.loads(encode(response)))

assert isinstance(reconstructed.message, StatusFlag)
assert reconstructed.message == combined


# --- Full end-to-end round-trip through the real client/server -------------


FLAG_INSTRUMENT_CLASS = (
"instrumentserver.testing.dummy_instruments.generic.DummyInstrumentWithFlags"
)


def test_end_to_end_get_intflag_parameter(cli):
"""A live client getting an IntFlag parameter receives the real enum type.

This is the case the user cares about: calling the parameter (not
snapshotting) should return a reconstructed ``StatusFlag`` instance on the
client side, not just its integer value.
"""
flag_ins = cli.find_or_create_instrument("flag_ins", FLAG_INSTRUMENT_CLASS)

status = flag_ins.status_byte()
assert isinstance(status, StatusFlag)
assert status == StatusFlag.EAV

condition = flag_ins.condition_register()
assert isinstance(condition, StatusFlag)
assert condition == (StatusFlag.EAV | StatusFlag.ESB)


def test_end_to_end_snapshot_with_intflag(cli):
"""The exact failing scenario: snapshotting an instrument with IntFlag values.

Before the fix this raised RecursionError server-side and the client got no
response. Here we assert the snapshot comes back intact through the full
serialize -> send -> deserialize pipeline.
"""
ins = cli.find_or_create_instrument("flag_ins_snap", FLAG_INSTRUMENT_CLASS)

# update=True forces the parameters to be read into the snapshot
snapshot = ins.get_snapshot(update=True)

assert isinstance(snapshot, dict)
params = snapshot["parameters"]
assert params["status_byte"]["value"] == StatusFlag.EAV.value # 4
assert (
params["condition_register"]["value"]
== (StatusFlag.EAV | StatusFlag.ESB).value # 36
)


def test_end_to_end_full_station_snapshot_with_intflag(cli):
"""Snapshot of the whole station (no instrument arg) including flag values.

This mirrors the user's reported call (``snapshot`` of the station with a
Yokogawa present) most closely.
"""
cli.find_or_create_instrument("flag_ins_station", FLAG_INSTRUMENT_CLASS)

station_snapshot = cli.get_snapshot(update=True)

assert isinstance(station_snapshot, dict)
instruments = station_snapshot["instruments"]
flag_params = instruments["flag_ins_station"]["parameters"]
assert flag_params["status_byte"]["value"] == StatusFlag.EAV.value
Loading