Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion crates/ppvm-stim/tests/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,23 @@ fn rx_pi_flips_qubit() {

#[test]
fn u3_pi_flip_via_y_axis() {
let (results, _) = run("I[U3(theta=1.0*pi, phi=0.0, lambda=0.0)] 0\nM 0", 1);
let (results, _) = run("I[U3(theta=1.0*pi, phi=0.0*pi, lambda=0.0*pi)] 0\nM 0", 1);
assert_eq!(results, vec![Some(true)]);
}

#[test]
fn u3_all_angles_nonzero_exercises_phi_lambda() {
// U3(theta=pi, phi=pi/2, lambda=pi/2) == Y (clifft Rz(phi)Ry(theta)Rz(lambda)),
// so H·U3·H == H·Y·H == -Y and |0> -> |1> deterministically. The H frame makes
// the outcome sensitive to phi *and* lambda (drop or mis-scale either and P(1)
// collapses to ~0.5). Half-turn tag args 1.0/0.5/0.5 each get *pi at lowering.
let tag = run(
"H 0\nI[U3(theta=1.0*pi, phi=0.5*pi, lambda=0.5*pi)] 0\nH 0\nM 0",
1,
);
assert_eq!(tag.0, vec![Some(true)]);
}

#[test]
fn t_gate_via_s_t_tag_no_op_on_zero() {
let (results, _) = run("S[T] 0\nM 0", 1);
Expand Down
9 changes: 8 additions & 1 deletion crates/stim-parser/src/ast/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ pub struct Tag {
#[derive(Debug, Clone, PartialEq)]
pub enum TagParam {
Positional(f64),
Named { key: String, value: f64 },
/// A `key=value` tag parameter. `had_pi` records whether the value was
/// written as a `<n>*pi` (or bare `pi`) expression — rotation/U3 tags
/// require it (half-turn convention), and the printer re-emits `*pi`.
Named {
key: String,
value: f64,
had_pi: bool,
},
}

/// The rotation axis for an extended-dialect `R_X` / `R_Y` / `R_Z` rotation.
Expand Down
75 changes: 68 additions & 7 deletions crates/stim-parser/src/pipeline/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ fn lower_gate(
match name {
// Native T / T_DAG mnemonics lower to the same sugar as `S[T]` / `S_DAG[T]`.
T | TDag => {
if let Some(tag) = tags.first() {
return invalid_tag(
&tag.name,
name.canonical_name(),
span,
"bare T/T_DAG take no tags; use S[T] / S_DAG[T] for the tagged form",
sink,
);
}
let Some(targets) = qubit_targets(targets, name.canonical_name(), span, sink)? else {
return Ok(None);
};
Expand Down Expand Up @@ -445,7 +454,7 @@ fn exact_named_params<const N: usize>(
sink,
);
}
TagParam::Named { key, value } => {
TagParam::Named { key, value, had_pi } => {
let Some(index) = required.iter().position(|required_key| key == required_key)
else {
return invalid_tag(
Expand All @@ -465,6 +474,17 @@ fn exact_named_params<const N: usize>(
sink,
);
}
// Rotation/U3 angles are in half-turns: require the `<n>*pi` form,
// mirroring tsim, so a bare number can't be mistaken for radians.
if !had_pi {
return invalid_tag(
&tag.name,
instruction,
span,
format!("parameter '{key}' must be written as <n>*pi (half-turns)"),
sink,
);
}
seen[index] = true;
values[index] = *value;
}
Expand Down Expand Up @@ -552,7 +572,7 @@ mod tests {

#[test]
fn identity_rotation_x_lowers() {
let prog = lower_extended("I[R_X(theta=0.5)] 0").expect("lower");
let prog = lower_extended("I[R_X(theta=0.5*pi)] 0").expect("lower");
match &prog.instructions[0] {
ExtendedInstruction::Rotation {
axis,
Expand All @@ -561,7 +581,7 @@ mod tests {
..
} => {
assert_eq!(*axis, Axis::X);
assert_eq!(*theta, 0.5);
assert!((*theta - 0.5 * std::f64::consts::PI).abs() < 1e-12);
assert_eq!(targets, &vec![0]);
}
other => panic!("{other:?}"),
Expand All @@ -570,7 +590,8 @@ mod tests {

#[test]
fn identity_u3_lowers() {
let prog = lower_extended("I[U3(theta=0.5, phi=1.0, lambda=1.5)] 0").expect("lower");
let prog =
lower_extended("I[U3(theta=0.5*pi, phi=1.0*pi, lambda=1.5*pi)] 0").expect("lower");
match &prog.instructions[0] {
ExtendedInstruction::U3 {
theta,
Expand All @@ -579,15 +600,55 @@ mod tests {
targets,
..
} => {
assert_eq!(*theta, 0.5);
assert_eq!(*phi, 1.0);
assert_eq!(*lambda, 1.5);
let pi = std::f64::consts::PI;
assert!((*theta - 0.5 * pi).abs() < 1e-12);
assert!((*phi - 1.0 * pi).abs() < 1e-12);
assert!((*lambda - 1.5 * pi).abs() < 1e-12);
assert_eq!(targets, &vec![0]);
}
other => panic!("{other:?}"),
}
}

#[test]
fn bare_t_with_tag_is_rejected() {
// A tag on bare T/T_DAG is meaningless (the tagged form is S[T]); reject
// it rather than silently dropping it.
let err = lower_extended("T[foo] 0").unwrap_err();
assert_eq!(err.last().unwrap().code, Some("invalid-tag"));
}

#[test]
fn bare_t_dag_with_tag_is_rejected() {
let err = lower_extended("T_DAG[foo] 0").unwrap_err();
assert_eq!(err.last().unwrap().code, Some("invalid-tag"));
}

#[test]
fn rotation_tag_without_pi_is_rejected() {
// Mirror tsim: rotation tag angles must be written as <n>*pi (half-turns).
let err = lower_extended("I[R_Z(theta=0.5)] 0").unwrap_err();
assert_eq!(err.last().unwrap().code, Some("invalid-tag"));
}

#[test]
fn u3_tag_without_pi_is_rejected() {
let err = lower_extended("I[U3(theta=0.5, phi=1.0, lambda=1.5)] 0").unwrap_err();
assert_eq!(err.last().unwrap().code, Some("invalid-tag"));
}

#[test]
fn rotation_tag_with_pi_is_accepted() {
let prog = lower_extended("I[R_Z(theta=0.5*pi)] 0").expect("lower");
match &prog.instructions[0] {
ExtendedInstruction::Rotation { axis, theta, .. } => {
assert_eq!(*axis, Axis::Z);
assert!((*theta - 0.5 * std::f64::consts::PI).abs() < 1e-12);
}
other => panic!("{other:?}"),
}
}

#[test]
fn i_error_loss_lowers() {
let prog = lower_extended("I_ERROR[loss](0.01) 0").expect("lower");
Expand Down
3 changes: 2 additions & 1 deletion crates/stim-parser/src/pipeline/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ mod tests {
TagParam::Named {
key: "theta".to_string(),
value: 0.25,
had_pi: false,
},
],
}],
Expand All @@ -559,7 +560,7 @@ mod tests {
assert!(matches!(tags[0].params[0], TagParam::Positional(0.5)));
assert!(matches!(
&tags[0].params[1],
TagParam::Named { key, value } if key == "theta" && *value == 0.25
TagParam::Named { key, value, .. } if key == "theta" && *value == 0.25
));
}
other => panic!("{other:?}"),
Expand Down
88 changes: 79 additions & 9 deletions crates/stim-parser/src/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ fn write_tags(out: &mut dyn fmt::Write, tags: &[Tag]) -> fmt::Result {
}
match p {
TagParam::Positional(v) => write!(out, "{}", FloatLit(*v))?,
TagParam::Named { key, value } => {
write!(out, "{key}={}", FloatLit(*value))?;
TagParam::Named { key, value, had_pi } => {
if *had_pi {
write!(out, "{key}={}*pi", FloatLit(pi_coeff(*value)))?;
} else {
write!(out, "{key}={}", FloatLit(*value))?;
}
}
}
}
Expand Down Expand Up @@ -165,6 +169,37 @@ impl fmt::Display for FloatLit {
}
}

/// The coefficient `c` to print for a `<c>*pi` literal carrying the radians
/// `value`. The naive `value / PI` is correct but, because the division
/// rounds, often prints a long tail (`0.76*pi` → `0.7599999999999999*pi`).
///
/// Parser-produced angles always originate from a decimal coefficient — a
/// `<n>*pi` rotation/U3 tag — so a short decimal `c` with
/// `c * PI == value` (bit-for-bit) exists. We return the shortest such `c`,
/// which prints cleanly *and* re-parses back to exactly `value`. Requiring
/// exact equality is what keeps `parse → print` lossless and the printer a
/// fixpoint; for any `value` with no exact short form we fall back to the
/// naive `value / PI` (same output as before).
fn pi_coeff(value: f64) -> f64 {
let pi = std::f64::consts::PI;
let q = value / pi;
if !q.is_finite() {
return q;
}
// `{:.*e}` with `prec` digits after the mantissa point is `prec + 1`
// significant digits; 17 sig-digits round-trips any f64, so by `prec = 16`
// `candidate == q` and the loop has tried every shorter rounding first.
for prec in 0..=16 {
let candidate: f64 = format!("{q:.prec$e}")
.parse()
.expect("a formatted float always re-parses");
if candidate * pi == value {
return candidate;
}
}
q
}

// ---------------------------------------------------------------------------
// StimPrint for shared *Op structs
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -291,7 +326,14 @@ impl StimPrint for ExtendedInstruction {
Axis::Y => "R_Y",
Axis::Z => "R_Z",
};
write!(out, "I[{}(theta={})]", axis_tag, FloatLit(*theta))?;
// theta is radians; re-emit the half-turn `<n>*pi` form the
// rotation tags require (see exact_named_params).
write!(
out,
"I[{}(theta={}*pi)]",
axis_tag,
FloatLit(pi_coeff(*theta))
)?;
write_usize_targets(out, targets)?;
}
ExtendedInstruction::U3 {
Expand All @@ -303,10 +345,10 @@ impl StimPrint for ExtendedInstruction {
} => {
write!(
out,
"I[U3(theta={}, phi={}, lambda={})]",
FloatLit(*theta),
FloatLit(*phi),
FloatLit(*lambda),
"I[U3(theta={}*pi, phi={}*pi, lambda={}*pi)]",
FloatLit(pi_coeff(*theta)),
FloatLit(pi_coeff(*phi)),
FloatLit(pi_coeff(*lambda)),
)?;
write_usize_targets(out, targets)?;
}
Expand Down Expand Up @@ -381,9 +423,9 @@ mod tests {

#[test]
fn extended_printed_form_lowers_sugar_into_canonical_stim() {
let src = "S[T] 0\nI[R_X(theta=0.25)] 1\nI_ERROR[loss](0.01) 2\n";
let src = "S[T] 0\nI[R_X(theta=0.25*pi)] 1\nI_ERROR[loss](0.01) 2\n";
let ast = parse_extended(src).unwrap();
let expected = "S[T] 0\nI[R_X(theta=0.25)] 1\nI_ERROR[loss](0.01) 2\n";
let expected = "S[T] 0\nI[R_X(theta=0.25*pi)] 1\nI_ERROR[loss](0.01) 2\n";
assert_eq!(ast.to_stim(), expected);
}

Expand All @@ -393,4 +435,32 @@ mod tests {
let ast = parse("CX rec[-1] 0\nMPP X0*Y1*Z2\n").unwrap();
assert_eq!(ast.to_stim(), "CX rec[-1] 0\nMPP X0*Y1*Z2\n");
}

#[test]
fn rotation_pi_coeff_prints_clean_and_round_trips() {
// theta is stored in radians as `c*PI`; printing `c = theta/PI` naively
// would emit a rounding tail like `0.7599999999999999*pi`. The printer
// recovers the short coefficient instead — for rotation and U3 tags —
// and `print → parse → print` stays a fixpoint.
for (src, expected) in [
// Non-binary-friendly decimals that `theta/PI` mangles.
("I[R_Z(theta=0.34*pi)] 0\n", "I[R_Z(theta=0.34*pi)] 0\n"),
("I[R_Y(theta=0.76*pi)] 1\n", "I[R_Y(theta=0.76*pi)] 1\n"),
("I[R_X(theta=-2.78*pi)] 2\n", "I[R_X(theta=-2.78*pi)] 2\n"),
(
"I[U3(theta=0.34*pi, phi=0.91*pi, lambda=0.07*pi)] 0\n",
"I[U3(theta=0.34*pi, phi=0.91*pi, lambda=0.07*pi)] 0\n",
),
] {
let printed = parse_extended(src).unwrap().to_stim();
assert_eq!(printed, expected, "first print of {src:?}");
assert!(
!printed.contains("999999") && !printed.contains("000000"),
"coefficient printed with a rounding tail: {printed:?}"
);
// Fixpoint: re-parsing and re-printing reproduces it byte-for-byte.
let reprinted = parse_extended(&printed).unwrap().to_stim();
assert_eq!(reprinted, printed, "printer is not a fixpoint for {src:?}");
}
}
}
25 changes: 17 additions & 8 deletions crates/stim-parser/src/syntax/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,29 @@ pub(crate) fn signed_float<'src>() -> impl Parser<'src, &'src str, f64, Extra<'s
.map(|s: &str| s.parse::<f64>().expect("validated by combinator shape"))
}

/// Pi-expression: `pi`, `<num>*pi`, or plain number. Evaluates to f64.
pub(crate) fn pi_expr<'src>() -> impl Parser<'src, &'src str, f64, Extra<'src>> + Clone {
let pi_kw = just("pi").to(std::f64::consts::PI);
/// Pi-expression, paired with whether `pi` actually appeared in the source:
/// `pi` -> `(PI, true)`, `<num>*pi` -> `(num*PI, true)`, `<num>` -> `(num, false)`.
/// The flag lets rotation/U3 tags enforce the half-turn `<n>*pi` convention.
pub(crate) fn pi_expr_flagged<'src>()
-> impl Parser<'src, &'src str, (f64, bool), Extra<'src>> + Clone {
let pi_kw = just("pi").to((std::f64::consts::PI, true));
let num_then_pi = signed_float()
.then(inline_pad().ignore_then(just("*pi")).or_not())
.map(|(n, suffix)| {
if suffix.is_some() {
n * std::f64::consts::PI
(n * std::f64::consts::PI, true)
} else {
n
(n, false)
}
});
choice((pi_kw, num_then_pi))
}

/// Pi-expression: `pi`, `<num>*pi`, or plain number. Evaluates to f64.
pub(crate) fn pi_expr<'src>() -> impl Parser<'src, &'src str, f64, Extra<'src>> + Clone {
pi_expr_flagged().map(|(value, _)| value)
}

use crate::ast::shared::{Tag, TagParam};

/// `<ident>=<pi_expr>` (Named) or `<pi_expr>` (Positional).
Expand All @@ -105,8 +113,8 @@ pub(crate) fn tag_param<'src>() -> impl Parser<'src, &'src str, TagParam, Extra<
.then_ignore(inline_pad())
.then_ignore(just('='))
.then_ignore(inline_pad())
.then(pi_expr())
.map(|(key, value)| TagParam::Named { key, value });
.then(pi_expr_flagged())
.map(|(key, (value, had_pi))| TagParam::Named { key, value, had_pi });
let positional = pi_expr().map(TagParam::Positional);
choice((named, positional))
}
Expand Down Expand Up @@ -343,9 +351,10 @@ mod tests {
assert_eq!(t.name, "R_X");
assert_eq!(t.params.len(), 1);
match &t.params[0] {
TagParam::Named { key, value } => {
TagParam::Named { key, value, had_pi } => {
assert_eq!(key, "theta");
assert!((value - 0.5 * std::f64::consts::PI).abs() < 1e-12);
assert!(had_pi);
}
other => panic!("{other:?}"),
}
Expand Down
Loading
Loading