diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b86269..f1bc2c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Types of changes: ### Fixed - Fixed classical register declarations not being visible inside `box` scope, causing "Missing clbit register declaration" errors for measurement statements inside box blocks. ([#306](https://github.com/qBraid/pyqasm/pull/306)) +- Fixed the backend-dependent `dt` duration unit being incorrectly relabeled as `ns` when unrolling `delay` and `box` statements without a `device_cycle_time`. Since `dt` cannot be converted to SI units without a sample rate, it is now preserved as `dt`. ([#301](https://github.com/qBraid/pyqasm/issues/301)) ### Dependencies - Bumped `actions/configure-pages` from 5 to 6. ([#307](https://github.com/qBraid/pyqasm/pull/307)) diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index e4c5dec..a9a565f 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -2800,6 +2800,23 @@ def _evaluate_case(statements): default_stmts = statement.default.statements return _evaluate_case(default_stmts) + def _resolve_duration_unit(self, time_var) -> qasm3_ast.TimeUnit: + """Determine the output unit for a duration literal. + + `dt` is backend-dependent and not convertible to SI units without a known + sample rate. Preserve it when the source unit was `dt` (or a device cycle + time is set). SI units are already converted to ns by the evaluator. + """ + source_is_dt = ( + isinstance(time_var, qasm3_ast.DurationLiteral) + and time_var.unit == qasm3_ast.TimeUnit.dt + ) + return ( + qasm3_ast.TimeUnit.dt + if self._module._device_cycle_time or source_is_dt + else qasm3_ast.TimeUnit.ns + ) + def _visit_delay_statement( self, statement: qasm3_ast.DelayInstruction ) -> list[qasm3_ast.Statement]: @@ -2828,12 +2845,7 @@ def _visit_delay_statement( if duration_val: PulseValidator.validate_duration_literal_value(duration_val, statement) statement.duration = qasm3_ast.DurationLiteral( - duration_val, - unit=( - qasm3_ast.TimeUnit.dt - if self._module._device_cycle_time - else qasm3_ast.TimeUnit.ns - ), + duration_val, unit=self._resolve_duration_unit(_delay_time_var) ) if self._scope_manager.in_box_scope(): @@ -2903,11 +2915,7 @@ def _visit_box_statement(self, statement: qasm3_ast.Box) -> list[qasm3_ast.State PulseValidator.validate_duration_literal_value(box_duration_val, statement) statement.duration = qasm3_ast.DurationLiteral( box_duration_val, - unit=( - qasm3_ast.TimeUnit.dt - if self._module._device_cycle_time - else qasm3_ast.TimeUnit.ns - ), + unit=self._resolve_duration_unit(_box_time_var), ) self._scope_manager.push_scope({}) self._scope_manager.increment_scope_level() @@ -2933,7 +2941,7 @@ def _visit_box_statement(self, statement: qasm3_ast.Box) -> list[qasm3_ast.State and box_duration_val and self._total_delay_duration_in_box > box_duration_val ): - time_unit = "dt" if self._module._device_cycle_time else "ns" + time_unit = self._resolve_duration_unit(_box_time_var).name raise_qasm3_error( f"Total delay duration value '{self._total_delay_duration_in_box}{time_unit}' " f"should be less than 'box[{box_duration_val}{time_unit}]' duration.", diff --git a/tests/qasm3/test_box.py b/tests/qasm3/test_box.py index c566c3b..f233c5c 100644 --- a/tests/qasm3/test_box.py +++ b/tests/qasm3/test_box.py @@ -102,6 +102,33 @@ def test_delay_instruction_device_time(): check_unrolled_qasm(dumps(module), expected_qasm) +def test_box_dt_unit_preserved(): + """A ``dt`` box duration must be preserved as ``dt`` (not relabeled ``ns``) + when no ``device_cycle_time`` is set, since ``dt`` is backend-dependent. + """ + qasm_str = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[1] q; + box[200dt] { + delay[100dt] q[0]; + x q[0]; + } + """ + expected_qasm = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[1] q; + box[200.0dt] { + delay[100.0dt] q[0]; + x q[0]; + } + """ + module = loads(qasm_str) + module.unroll() + check_unrolled_qasm(dumps(module), expected_qasm) + + @pytest.mark.parametrize( "qasm_code,error_message,error_span", [ @@ -199,6 +226,21 @@ def test_delay_instruction_device_time(): r"Total delay duration value '20.0ns' should be less than 'box[10.0ns]' duration.", r"Error at line 6, column 12", ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[5] q; + qubit[2] q2; + box[10dt] { + delay[20dt]; + x q[1]; + measure q; + } + """, + r"Total delay duration value '20.0dt' should be less than 'box[10.0dt]' duration.", + r"Error at line 6, column 12", + ), ( """ OPENQASM 3.0; diff --git a/tests/qasm3/test_delay.py b/tests/qasm3/test_delay.py index ceff3b9..f444e11 100644 --- a/tests/qasm3/test_delay.py +++ b/tests/qasm3/test_delay.py @@ -66,6 +66,30 @@ def test_delay_instruction_device_time(): check_unrolled_qasm(dumps(module), expected_qasm) +def test_delay_dt_unit_preserved(): + """A ``dt`` delay literal must be preserved as ``dt`` (not relabeled ``ns``) + when no ``device_cycle_time`` is set, since ``dt`` is backend-dependent. SI + units are still converted to ns. + """ + qasm_str = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[2] q; + delay[100dt] q[0]; + delay[2us] q[1]; + """ + expected_qasm = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[2] q; + delay[100.0dt] q[0]; + delay[2000.0ns] q[1]; + """ + module = loads(qasm_str) + module.unroll() + check_unrolled_qasm(dumps(module), expected_qasm) + + @pytest.mark.parametrize( "qasm_code,error_message,error_span", [