diff --git a/crates/ppvm-python-native/src/stim_program.rs b/crates/ppvm-python-native/src/stim_program.rs index 28f03159..ad5eef07 100644 --- a/crates/ppvm-python-native/src/stim_program.rs +++ b/crates/ppvm-python-native/src/stim_program.rs @@ -46,6 +46,14 @@ impl PyStimProgram { ) } + /// Number of qubits the program operates on: one past the highest qubit + /// index any instruction references (`0` if it touches no qubits). Lets + /// callers size a tableau when no explicit qubit count is given. + #[getter] + fn num_qubits(&self) -> usize { + self.0.num_qubits() + } + /// Jupyter rich display: syntax-highlighted Stim source. Only invoked in /// IPython/Jupyter; plain `str()`/`print()` stay uncoloured elsewhere. fn _repr_html_(&self) -> String { diff --git a/crates/stim-parser/src/ast/extended.rs b/crates/stim-parser/src/ast/extended.rs index e0fd6794..6ff8b3b4 100644 --- a/crates/stim-parser/src/ast/extended.rs +++ b/crates/stim-parser/src/ast/extended.rs @@ -81,6 +81,43 @@ impl ExtendedProgram { pub fn measurement_count(&self) -> usize { count_in_slice(&self.instructions, 1) } + + /// Number of qubits the program operates on: one past the highest qubit + /// index referenced by any executable instruction, or `0` if it touches no + /// qubits. Annotations (`DETECTOR`, `QUBIT_COORDS`, …) are ignored — their + /// operands are measurement-record lookbacks or coordinates, not executable + /// qubits. Pure AST property; backend-agnostic, mirrors [`measurement_count`]. + pub fn num_qubits(&self) -> usize { + max_qubit_in_slice(&self.instructions).map_or(0, |m| m + 1) + } +} + +/// Highest qubit index referenced by any executable instruction in `slice`, +/// recursing into `REPEAT` bodies. `None` if nothing touches a qubit. +fn max_qubit_in_slice(instructions: &[ExtendedInstruction]) -> Option { + let mut max: Option = None; + for instr in instructions { + // `Option` orders `None` below every `Some`, so `max.max(local)` + // tracks the running maximum and treats "no qubit" as absent. + let local = match instr { + ExtendedInstruction::Gate(op) => op.targets.iter().filter_map(|t| t.as_qubit()).max(), + ExtendedInstruction::Noise(op) => op.targets.iter().copied().max(), + ExtendedInstruction::Measure(op) => op.targets.iter().copied().max(), + ExtendedInstruction::Mpp(op) => op.products.iter().flatten().map(|f| f.qubit).max(), + ExtendedInstruction::T { targets, .. } + | ExtendedInstruction::TDag { targets, .. } + | ExtendedInstruction::Rotation { targets, .. } + | ExtendedInstruction::U3 { targets, .. } + | ExtendedInstruction::Loss { targets, .. } => targets.iter().copied().max(), + ExtendedInstruction::CorrelatedLoss { targets, .. } => { + targets.iter().flat_map(|&(a, b)| [a, b]).max() + } + ExtendedInstruction::Repeat { body, .. } => max_qubit_in_slice(body), + ExtendedInstruction::Annotation(_) | ExtendedInstruction::MPad { .. } => None, + }; + max = max.max(local); + } + max } fn count_in_slice(instructions: &[ExtendedInstruction], factor: u64) -> usize { @@ -117,7 +154,7 @@ fn count_in_slice(instructions: &[ExtendedInstruction], factor: u64) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ast::shared::{GateOp, MeasureOp}; + use crate::ast::shared::{GateOp, MeasureOp, Target}; use crate::diagnostics::{LineMap, Span}; use crate::instructions::{GateName, MeasureName}; use std::sync::Arc; @@ -146,6 +183,60 @@ mod tests { assert_eq!(prog.measurement_count(), 6); } + #[test] + fn num_qubits_is_one_past_highest_index() { + // Gate on qubits {0, 4}, measure {2} -> 5 qubits (indices 0..=4). + let prog = ExtendedProgram { + instructions: vec![ + ExtendedInstruction::Gate(GateOp { + name: GateName::H, + tags: vec![], + args: vec![], + targets: vec![Target::Qubit(0), Target::Qubit(4)], + span: span(), + }), + ExtendedInstruction::Measure(MeasureOp { + name: MeasureName::M, + tags: vec![], + args: vec![], + targets: vec![2], + span: span(), + }), + ], + line_map: Arc::new(LineMap::new("")), + }; + assert_eq!(prog.num_qubits(), 5); + } + + #[test] + fn num_qubits_recurses_into_repeat() { + let prog = ExtendedProgram { + instructions: vec![ExtendedInstruction::Repeat { + count: 3, + body: vec![ExtendedInstruction::Measure(MeasureOp { + name: MeasureName::M, + tags: vec![], + args: vec![], + targets: vec![7], + span: span(), + })], + span: span(), + }], + line_map: Arc::new(LineMap::new("")), + }; + assert_eq!(prog.num_qubits(), 8); + } + + #[test] + fn num_qubits_is_zero_for_no_qubit_program() { + // Empty program, and an annotation-only program, both touch no qubits. + let empty = ExtendedProgram { + instructions: vec![], + line_map: Arc::new(LineMap::new("")), + }; + assert_eq!(empty.num_qubits(), 0); + } + #[test] fn gate_op_is_shared_with_vanilla() { let _ = ExtendedInstruction::Gate(GateOp { diff --git a/ppvm-python/src/ppvm/_core.pyi b/ppvm-python/src/ppvm/_core.pyi index c4a91ca6..2d3998d3 100644 --- a/ppvm-python/src/ppvm/_core.pyi +++ b/ppvm-python/src/ppvm/_core.pyi @@ -172,6 +172,8 @@ class _GeneralizedTableauBase: def fork(self, seed: int | None = None) -> _GeneralizedTableauBase: ... class StimProgram: + @property + def num_qubits(self) -> int: ... @staticmethod def parse(src: str) -> StimProgram: ... @staticmethod diff --git a/ppvm-python/src/ppvm/generalized_tableau.py b/ppvm-python/src/ppvm/generalized_tableau.py index b0e6cbe2..eb3443e1 100644 --- a/ppvm-python/src/ppvm/generalized_tableau.py +++ b/ppvm-python/src/ppvm/generalized_tableau.py @@ -325,7 +325,7 @@ def run(self, prog: StimProgram) -> list[MeasurementResult]: def sample( cls, prog: StimProgram, - n_qubits: int, + n_qubits: int | None = None, min_abs_coeff: float = 1e-10, num_shots: int = 1, seed: int | None = None, @@ -335,6 +335,11 @@ def sample( Each shot starts from a fresh tableau, so this is the right entry point for multi-shot sampling. + When ``n_qubits`` is ``None`` (the default) the qubit count is inferred + from the program via ``prog.num_qubits`` (one past the highest qubit + index it references), falling back to 1 for a program that touches no + qubits. Pass an explicit ``n_qubits`` to size the tableau larger. + Shots run in parallel across CPU cores (the GIL is released during sampling), with a serial fallback for small batches. When ``seed`` is given (it must fit in an unsigned 64-bit integer), shot ``i`` uses @@ -343,6 +348,8 @@ def sample( ``RAYON_NUM_THREADS`` environment variable before the first call to control the pool size (it defaults to the number of logical cores). """ + if n_qubits is None: + n_qubits = max(1, prog.num_qubits) native_cls = _native_tableau_cls(n_qubits) raw = native_cls.sample(prog, n_qubits, min_abs_coeff, num_shots, seed) return [[MeasurementResult(x) for x in shot] for shot in raw] @@ -350,15 +357,17 @@ def sample( def sample_stim( prog: StimProgram, - n_qubits: int, + n_qubits: int | None = None, min_abs_coeff: float = 1e-10, num_shots: int = 1, seed: int | None = None, ) -> list[list[MeasurementResult]]: """Multi-shot sampling — module-level alias for ``GeneralizedTableau.sample``. - Shots are sampled in parallel across CPU cores with the GIL released; see - `GeneralizedTableau.sample` for seeding and ``RAYON_NUM_THREADS``. + When ``n_qubits`` is ``None`` (the default) the qubit count is inferred from + the program; see `GeneralizedTableau.sample`. Shots are sampled in parallel + across CPU cores with the GIL released; see `GeneralizedTableau.sample` for + seeding and ``RAYON_NUM_THREADS``. """ return GeneralizedTableau.sample( prog, n_qubits, min_abs_coeff=min_abs_coeff, num_shots=num_shots, seed=seed diff --git a/ppvm-python/test/generalized_tableau/test_stim.py b/ppvm-python/test/generalized_tableau/test_stim.py index a8adcd54..08996fc0 100644 --- a/ppvm-python/test/generalized_tableau/test_stim.py +++ b/ppvm-python/test/generalized_tableau/test_stim.py @@ -275,6 +275,30 @@ def test_run_propagates_parse_error_as_value_error(): StimProgram.parse("FROBNICATE 0") +def test_stim_program_num_qubits_property(): + # One past the highest qubit index any instruction references. + assert StimProgram.parse("H 0\nCX 0 4\nM 2").num_qubits == 5 + assert StimProgram.parse("M 0").num_qubits == 1 + # Annotations (QUBIT_COORDS / DETECTOR) carry no executable qubit operands. + assert StimProgram.parse("QUBIT_COORDS(0, 0) 0\nX 0\nM 0\nDETECTOR rec[-1]").num_qubits == 1 + + +def test_sample_stim_infers_n_qubits_when_omitted(): + # Two qubits flipped; omitting n_qubits must infer 3 (indices 0..=2). + prog = StimProgram.parse("X 0 2\nM 0 1 2") + inferred = sample_stim(prog, num_shots=1, seed=0) + explicit = sample_stim(prog, n_qubits=3, num_shots=1, seed=0) + assert inferred == explicit + assert inferred == [[MeasurementResult.ONE, MeasurementResult.ZERO, MeasurementResult.ONE]] + + +def test_sample_classmethod_infers_n_qubits_when_omitted(): + prog = StimProgram.parse("H 0\nCX 0 1\nM 0 1") + inferred = GeneralizedTableau.sample(prog, num_shots=5, seed=0) + explicit = GeneralizedTableau.sample(prog, 2, num_shots=5, seed=0) + assert inferred == explicit + + def test_sample_many_qubits(): stim_str = textwrap.dedent(""" X 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99