From 5f9e8724823a57832631b04a2c66e301fa2bcd93 Mon Sep 17 00:00:00 2001 From: MauroFab Date: Mon, 29 Jun 2026 17:17:42 -0300 Subject: [PATCH 1/6] spike(stark): builder-rewrite (Plan B) capture of transition constraints into the IR Counterpart to PR #737 (symbolic field) for comparing the two capture front-ends. Reuses the same IR + CPU interpreter; adds an explicit IrBuilder + object-safe Capture trait, with capture() implemented on IsBit/Add/ProductZero ALONGSIDE their unchanged evaluate (non-destructive). Diff test matches real evaluate bit-for-bit over 1000 random rows; captured node counts 4-21 (vs A's 66-78, since the builder only emits leaves for columns actually read). CPU-only minimal spike, same scope as #737. --- crypto/stark/src/constraint_ir/builder.rs | 203 ++++++++++++++++++++++ crypto/stark/src/constraint_ir/interp.rs | 97 +++++++++++ crypto/stark/src/constraint_ir/ir.rs | 80 +++++++++ crypto/stark/src/constraint_ir/mod.rs | 39 +++++ crypto/stark/src/lib.rs | 1 + prover/src/constraints/cpu.rs | 11 ++ prover/src/constraints/templates.rs | 135 ++++++++++++++ prover/src/tests/constraint_ir_tests.rs | 113 ++++++++++++ prover/src/tests/mod.rs | 2 + 9 files changed, 681 insertions(+) create mode 100644 crypto/stark/src/constraint_ir/builder.rs create mode 100644 crypto/stark/src/constraint_ir/interp.rs create mode 100644 crypto/stark/src/constraint_ir/ir.rs create mode 100644 crypto/stark/src/constraint_ir/mod.rs create mode 100644 prover/src/tests/constraint_ir_tests.rs diff --git a/crypto/stark/src/constraint_ir/builder.rs b/crypto/stark/src/constraint_ir/builder.rs new file mode 100644 index 000000000..29328d9b2 --- /dev/null +++ b/crypto/stark/src/constraint_ir/builder.rs @@ -0,0 +1,203 @@ +//! Explicit-builder capture front-end (Plan B). +//! +//! Where the symbolic-field front-end (Plan A) records IR by running a +//! constraint's generic `evaluate` over recording field types, this front-end +//! builds the same [`ConstraintProgram`] through an explicit [`IrBuilder`]: +//! each constraint implements [`Capture`](super::Capture) and translates its +//! `evaluate` body into builder calls (`main`, `add`, `sub`, `mul`, ...). +//! +//! No fake field, no thread-local arena. The builder hash-conses every node on +//! `(Op, Dim)` and only emits leaves for columns the constraint actually reads, +//! so captured programs are minimal. + +use std::collections::HashMap; + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; + +use super::ir::{ConstraintProgram, Dim, Op}; + +/// A handle to a node in an [`IrBuilder`]: its arena id and result dimension. +/// +/// `Copy` so constraint bodies read like ordinary field arithmetic. +#[derive(Clone, Copy, Debug)] +pub struct Expr { + id: u32, + dim: Dim, +} + +impl Expr { + /// The node's result dimension. + pub fn dim(self) -> Dim { + self.dim + } +} + +/// Builds a [`ConstraintProgram`] from explicit node-construction calls. +/// +/// Nodes are appended in topological order (id `i` references only `< i`) and +/// hash-consed on `(Op, Dim)`, so structurally identical subexpressions share a +/// single id. Base-field constants are additionally deduplicated by value via +/// `const_cache`. Node id `0` is reserved for `Op::Const1(0)`, matching the +/// interpreter's convention and Plan A's arena. +pub struct IrBuilder { + nodes: Vec, + dims: Vec, + cse: HashMap<(Op, Dim), u32>, + const_cache: HashMap, + roots: Vec, +} + +impl Default for IrBuilder { + fn default() -> Self { + Self::new() + } +} + +impl IrBuilder { + /// Create a builder with the reserved `Op::Const1(0)` node at id 0. + pub fn new() -> Self { + let mut b = IrBuilder { + nodes: Vec::new(), + dims: Vec::new(), + cse: HashMap::new(), + const_cache: HashMap::new(), + roots: Vec::new(), + }; + // Reserve id 0 = Const1(0). `const_base(0)` will hash-cons to this. + let zero = b.push(Op::Const1(0), Dim::D1); + debug_assert_eq!(zero.id, 0); + b.const_cache.insert(0, 0); + b + } + + /// Append (or reuse) a node with the given op and result dimension. + fn push(&mut self, op: Op, dim: Dim) -> Expr { + if let Some(&id) = self.cse.get(&(op, dim)) { + return Expr { id, dim }; + } + let id = self.nodes.len() as u32; + self.nodes.push(op); + self.dims.push(dim); + self.cse.insert((op, dim), id); + Expr { id, dim } + } + + // --------------------------------------------------------------------- + // Leaves + // --------------------------------------------------------------------- + + /// A main-trace column read at the given frame `offset`, row 0. + pub fn main(&mut self, offset: u8, col: usize) -> Expr { + self.push( + Op::Var { + main: true, + offset, + row: 0, + col: col as u16, + }, + Dim::D1, + ) + } + + /// An aux-trace column read at the given frame `offset`, row 0 (`D3`). + pub fn aux(&mut self, offset: u8, col: usize) -> Expr { + self.push( + Op::Var { + main: false, + offset, + row: 0, + col: col as u16, + }, + Dim::D3, + ) + } + + // --------------------------------------------------------------------- + // Constants + // --------------------------------------------------------------------- + + /// A base-field constant from a `u64`, reduced and deduplicated by value. + pub fn const_base(&mut self, v: u64) -> Expr { + let canon = *FieldElement::::from(v).value(); + if let Some(&id) = self.const_cache.get(&canon) { + return Expr { id, dim: Dim::D1 }; + } + let e = self.push(Op::Const1(canon), Dim::D1); + self.const_cache.insert(canon, e.id); + e + } + + /// A base-field constant from an `i64`; negatives map to `p - |v|`. + pub fn const_signed(&mut self, v: i64) -> Expr { + let canon = *FieldElement::::from(v).value(); + if let Some(&id) = self.const_cache.get(&canon) { + return Expr { id, dim: Dim::D1 }; + } + let e = self.push(Op::Const1(canon), Dim::D1); + self.const_cache.insert(canon, e.id); + e + } + + /// The base-field constant `1`. + pub fn one(&mut self) -> Expr { + self.const_base(1) + } + + // --------------------------------------------------------------------- + // Arithmetic + // --------------------------------------------------------------------- + + /// `a + b`. Result is `D1` only if both operands are `D1`. + pub fn add(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Add(a.id, b.id), dim) + } + + /// `a - b`. Result is `D1` only if both operands are `D1`. + pub fn sub(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Sub(a.id, b.id), dim) + } + + /// `a * b`. Result is `D1` only if both operands are `D1`. + pub fn mul(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Mul(a.id, b.id), dim) + } + + /// `-a`. Preserves the operand's dimension. + pub fn neg(&mut self, a: Expr) -> Expr { + self.push(Op::Neg(a.id), a.dim) + } + + /// Typing join: `(D1, D1) -> D1`; any `D3` operand -> `D3`. + fn join(a: Dim, b: Dim) -> Dim { + match (a, b) { + (Dim::D1, Dim::D1) => Dim::D1, + _ => Dim::D3, + } + } + + // --------------------------------------------------------------------- + // Emit / finish + // --------------------------------------------------------------------- + + /// Record `e` as the root for constraint `constraint_idx`. + /// + /// Roots are stored in emit order; the minimal spike emits exactly one root + /// per program, so `constraint_idx` is accepted for parity with the + /// production design but not used to index `roots` here. + pub fn emit(&mut self, _constraint_idx: usize, e: Expr) { + self.roots.push(e.id); + } + + /// Consume the builder and produce the captured program. + pub fn finish(self) -> ConstraintProgram { + ConstraintProgram { + nodes: self.nodes, + dims: self.dims, + roots: self.roots, + } + } +} diff --git a/crypto/stark/src/constraint_ir/interp.rs b/crypto/stark/src/constraint_ir/interp.rs new file mode 100644 index 000000000..62e502594 --- /dev/null +++ b/crypto/stark/src/constraint_ir/interp.rs @@ -0,0 +1,97 @@ +//! CPU interpreter for a captured [`ConstraintProgram`]. +//! +//! A single forward pass over the topologically ordered nodes evaluates each +//! node into a [`Value`] (base `D1` or extension `D3`), reusing the real +//! `FieldElement` arithmetic so per-op results are bit-identical to the boxed +//! constraint path. Mixed-dimension ops auto-embed the `D1` operand into `D3`, +//! mirroring the field tower's `F: IsSubFieldOf` arithmetic. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField as GoldilocksExtension; +use math::field::goldilocks::GoldilocksField; + +use super::ir::{ConstraintProgram, Dim, Op}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +/// A node's computed value: base field (`D1`) or degree-3 extension (`D3`). +#[derive(Clone, Copy, Debug)] +enum Value { + D1(Fp), + D3(Fp3), +} + +impl Value { + /// Promote to the extension field, embedding a base value if needed. + fn to_ext(self) -> Fp3 { + match self { + Value::D1(x) => x.to_extension::(), + Value::D3(x) => x, + } + } + + fn as_base(self) -> Fp { + match self { + Value::D1(x) => x, + Value::D3(_) => { + panic!("expected a base (D1) value but found an extension (D3) value") + } + } + } +} + +/// Evaluate the program's single root over a base-field main row. +/// +/// `main_row[col]` resolves `Var { main: true, col, .. }` leaves. The minimal +/// algebraic constraint set only reads main columns at offset 0, row 0 and +/// returns a base-field (`D1`) value, so this returns a `FieldElement`. +pub fn eval_program_base(prog: &ConstraintProgram, main_row: &[Fp]) -> Fp { + let mut values: Vec = Vec::with_capacity(prog.nodes.len()); + + for (i, op) in prog.nodes.iter().enumerate() { + let v = match *op { + Op::Const1(c) => Value::D1(Fp::from(c)), + Op::Const3([c0, c1, c2]) => { + Value::D3(Fp3::from_raw([Fp::from(c0), Fp::from(c1), Fp::from(c2)])) + } + Op::Var { main, row, col, .. } => { + assert!(main, "aux leaves are not part of the minimal algebraic set"); + assert_eq!(row, 0, "minimal set reads row 0 only"); + Value::D1(main_row[col as usize]) + } + Op::Add(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x + y, |x, y| x + y), + Op::Sub(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x - y, |x, y| x - y), + Op::Mul(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x * y, |x, y| x * y), + Op::Neg(a) => match (values[a as usize], prog.dims[i]) { + (Value::D1(x), Dim::D1) => Value::D1(-x), + (val, Dim::D3) => Value::D3(-val.to_ext()), + (Value::D3(x), Dim::D1) => Value::D3(-x), // dim mismatch, keep ext + }, + Op::Embed(a) => Value::D3(values[a as usize].to_ext()), + }; + values.push(v); + } + + let root = prog.roots[0]; + values[root as usize].as_base() +} + +/// Apply a binary op, auto-embedding to the extension field when the result +/// dimension is `D3` (or either operand is already `D3`). +#[inline] +fn binop( + values: &[Value], + a: u32, + b: u32, + result_dim: Dim, + base_op: impl Fn(Fp, Fp) -> Fp, + ext_op: impl Fn(Fp3, Fp3) -> Fp3, +) -> Value { + let va = values[a as usize]; + let vb = values[b as usize]; + match (va, vb, result_dim) { + (Value::D1(x), Value::D1(y), Dim::D1) => Value::D1(base_op(x, y)), + _ => Value::D3(ext_op(va.to_ext(), vb.to_ext())), + } +} diff --git a/crypto/stark/src/constraint_ir/ir.rs b/crypto/stark/src/constraint_ir/ir.rs new file mode 100644 index 000000000..8d0a3c449 --- /dev/null +++ b/crypto/stark/src/constraint_ir/ir.rs @@ -0,0 +1,80 @@ +//! Flat intermediate representation (IR) for captured transition constraints. +//! +//! A [`ConstraintProgram`] is a topologically ordered list of [`Op`] nodes plus +//! a per-constraint root id. It is produced by the builder capture front-end +//! (see [`crate::constraint_ir::builder`]) and consumed by the CPU interpreter +//! (see [`crate::constraint_ir::interp`]). +//! +//! The IR is single-field over Goldilocks, with a [`Dim`] tag distinguishing +//! base (`D1`, one `u64`) from the degree-3 extension (`D3`, three `u64`). + +/// Field-arithmetic dimension of a node's value: base Goldilocks (`D1`) or its +/// degree-3 extension (`D3`). +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)] +pub enum Dim { + /// Base field (one Goldilocks `u64`). + #[default] + D1, + /// Degree-3 extension (`[u64; 3]`). + D3, +} + +/// One IR instruction. Operand fields are `u32` ids into the program's `nodes` +/// arena; a node with id `i` only references nodes with id `< i`. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum Op { + /// A base-field literal (already reduced mod the Goldilocks prime). + Const1(u64), + /// An extension-field literal `[c0, c1, c2]` (each component reduced). + Const3([u64; 3]), + /// A leaf read of a main-trace cell. `main` is always `true` for the + /// minimal algebraic set captured by the spike; aux reads would set it + /// `false`. `offset`/`row` select the frame step/row, `col` the column. + Var { + /// `true` for a main-trace column read, `false` for an aux read. + main: bool, + /// Frame step index (0-based). + offset: u8, + /// Row within the step. + row: u8, + /// Column index. + col: u16, + }, + /// `nodes[a] + nodes[b]`. + Add(u32, u32), + /// `nodes[a] - nodes[b]`. + Sub(u32, u32), + /// `nodes[a] * nodes[b]`. + Mul(u32, u32), + /// `-nodes[a]`. + Neg(u32), + /// Embed a `D1` value into `D3` (`>::embed`). + Embed(u32), +} + +/// A captured program for one transition constraint (or a set of them). +/// +/// `nodes` is topologically ordered (id `i` references only `< i`). `dims[i]` +/// is the result dimension of `nodes[i]`. `roots[c]` is the node id of +/// constraint `c`'s value. +#[derive(Clone, Debug)] +pub struct ConstraintProgram { + /// Topologically ordered instruction list. + pub nodes: Vec, + /// Per-node result dimension, parallel to `nodes`. + pub dims: Vec, + /// Per-constraint root node ids. + pub roots: Vec, +} + +impl ConstraintProgram { + /// Number of nodes in the program (an effectiveness measure for hash-consing). + pub fn len(&self) -> usize { + self.nodes.len() + } + + /// Whether the program has no nodes. + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } +} diff --git a/crypto/stark/src/constraint_ir/mod.rs b/crypto/stark/src/constraint_ir/mod.rs new file mode 100644 index 000000000..a515ff177 --- /dev/null +++ b/crypto/stark/src/constraint_ir/mod.rs @@ -0,0 +1,39 @@ +//! Explicit-builder constraint capture spike (Plan B). +//! +//! Proof-of-concept that lambda_vm's algebraic transition constraints can be +//! captured into a flat, single-field Goldilocks IR via an explicit +//! [`IrBuilder`] (rather than the recording "symbolic field" of Plan A), and +//! that interpreting that IR on the CPU reproduces the constraint's real +//! `evaluate` bit-for-bit. +//! +//! Both plans produce the SAME IR and use the SAME interpreter; they differ +//! only in the capture front-end. Here each constraint implements [`Capture`] +//! and translates its `evaluate` body into builder calls. This is CPU-only and +//! does not touch the prover hot loop, the LogUp framework, or GPU code. +//! +//! - [`ir`]: the IR data structures ([`ConstraintProgram`], [`Op`], [`Dim`]). +//! - [`builder`]: the [`IrBuilder`] and [`Expr`] capture API. +//! - [`interp`]: a CPU forward-pass interpreter over the IR. +//! +//! [`ConstraintProgram`]: ir::ConstraintProgram +//! [`Op`]: ir::Op +//! [`Dim`]: ir::Dim + +pub mod builder; +pub mod interp; +pub mod ir; + +pub use builder::{Expr, IrBuilder}; +pub use interp::eval_program_base; +pub use ir::{ConstraintProgram, Dim, Op}; + +/// A transition constraint that can record its algebra into an [`IrBuilder`]. +/// +/// Object-safe: `capture` is non-generic (it takes `&mut IrBuilder`), so a +/// constraint can be captured behind a `&dyn Capture`, mirroring the production +/// design where the capture method is not generic over the field tower. +pub trait Capture { + /// Translate this constraint's algebra into builder nodes, finishing with a + /// single `b.emit(constraint_idx, root)` call. + fn capture(&self, b: &mut IrBuilder); +} diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index e9f6a1cda..5ec372c23 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -5,6 +5,7 @@ compile_error!("the `disk-spill` feature requires memmap2, which does not compil #[cfg(feature = "debug-checks")] pub mod bus_debug; +pub mod constraint_ir; pub mod constraints; pub mod context; pub mod debug; diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index facc9e16d..1c811471b 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -17,6 +17,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::table::TableView; @@ -112,6 +113,16 @@ impl TransitionConstraint for ProductZeroC } } +impl Capture for ProductZeroConstraint { + fn capture(&self, b: &mut IrBuilder) { + // col_a * col_b + let a = b.main(0, self.col_a); + let b_col = b.main(0, self.col_b); + let root = b.mul(a, b_col); + b.emit(self.constraint_idx, root); + } +} + /// `(1 - MEMORY - BRANCH) · read_register2 · imm[i] = 0`: when neither MEMORY nor /// BRANCH is set, the `arg2` multiplex needs at most one of `rv2`/`imm` nonzero. /// Decoding already guarantees this; a spec defense-in-depth assumption. diff --git a/prover/src/constraints/templates.rs b/prover/src/constraints/templates.rs index ef5b6c036..daf25ae6d 100644 --- a/prover/src/constraints/templates.rs +++ b/prover/src/constraints/templates.rs @@ -13,6 +13,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::{constraints::transition::TransitionConstraint, table::TableView}; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; @@ -107,6 +108,28 @@ impl TransitionConstraint for IsBitConstra } } +impl Capture for IsBitConstraint { + fn capture(&self, b: &mut IrBuilder) { + // Mirrors `evaluate`: x = main(value_col), one - x, then the product. + let x = b.main(0, self.value_col); + let one = b.one(); + let one_minus_x = b.sub(one, x); + + let root = match self.cond_col { + Some(cond_col) => { + // cond * x * (1 - x), left-associated like `&cond * &x * (one - x)`. + let cond = b.main(0, cond_col); + let cond_x = b.mul(cond, x); + b.mul(cond_x, one_minus_x) + } + // x * (1 - x) + None => b.mul(x, one_minus_x), + }; + + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // ADD Template (Embedded Carry Approach) // ========================================================================= @@ -177,6 +200,22 @@ impl AddLinearTerm { AddLinearTerm::Constant(value) => FieldElement::::from(*value), } } + + /// Capture this term into builder nodes, mirroring [`Self::eval`]. + fn capture(&self, b: &mut IrBuilder) -> Expr { + match self { + AddLinearTerm::Column { + coefficient, + column, + } => { + // `col_val * FieldElement::from(coeff)`: column on the left. + let col = b.main(0, *column); + let coeff = b.const_signed(*coefficient); + b.mul(col, coeff) + } + AddLinearTerm::Constant(value) => b.const_signed(*value), + } + } } /// Evaluate a slice of terms as a sum. @@ -195,6 +234,24 @@ where } } +/// Capture a slice of terms as a sum, mirroring [`eval_terms`]. +/// +/// Empty -> `0`; otherwise `0 + t0 + t1 + ...` (same fold seed and order as +/// `eval_terms`, so the captured node tree matches bit-for-bit). +fn capture_terms(terms: &[AddLinearTerm], b: &mut IrBuilder) -> Expr { + let zero = b.const_base(0); + if terms.is_empty() { + zero + } else { + let mut acc = zero; + for t in terms { + let term = t.capture(b); + acc = b.add(acc, term); + } + acc + } +} + impl AddOperand { /// Get the low word value from the trace. pub fn eval_lo(&self, step: &TableView) -> FieldElement @@ -224,6 +281,22 @@ impl AddOperand { } } + /// Capture the low word, mirroring [`Self::eval_lo`]. + pub fn capture_lo(&self, b: &mut IrBuilder) -> Expr { + match self { + AddOperand::DWordWL { start_column } => b.main(0, *start_column), + AddOperand::Linear { lo, .. } => capture_terms(lo, b), + } + } + + /// Capture the high word, mirroring [`Self::eval_hi`]. + pub fn capture_hi(&self, b: &mut IrBuilder) -> Expr { + match self { + AddOperand::DWordWL { start_column } => b.main(0, *start_column + 1), + AddOperand::Linear { hi, .. } => capture_terms(hi, b), + } + } + // ------------------------------------------------------------------------- // Convenience constructors for common cast types // ------------------------------------------------------------------------- @@ -485,6 +558,68 @@ impl TransitionConstraint for AddConstrain } } +impl AddConstraint { + /// Capture carry_0, mirroring [`Self::compute_carry_0`]. + fn capture_carry_0(&self, b: &mut IrBuilder) -> Expr { + let lhs_lo = self.lhs.capture_lo(b); + let rhs_lo = self.rhs.capture_lo(b); + let sum_lo = self.sum.capture_lo(b); + let inv = b.const_base(INV_SHIFT_32); + + // ((lhs_lo + rhs_lo) - sum_lo) * inv_2_32 + let s = b.add(lhs_lo, rhs_lo); + let s = b.sub(s, sum_lo); + b.mul(s, inv) + } + + /// Capture carry_1, mirroring [`Self::compute_carry_1`]. + fn capture_carry_1(&self, b: &mut IrBuilder) -> Expr { + let lhs_hi = self.lhs.capture_hi(b); + let rhs_hi = self.rhs.capture_hi(b); + let sum_hi = self.sum.capture_hi(b); + let carry_0 = self.capture_carry_0(b); + let inv = b.const_base(INV_SHIFT_32); + + // (((lhs_hi + rhs_hi) + carry_0) - sum_hi) * inv_2_32 + let s = b.add(lhs_hi, rhs_hi); + let s = b.add(s, carry_0); + let s = b.sub(s, sum_hi); + b.mul(s, inv) + } +} + +impl Capture for AddConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let carry = match self.carry_idx { + 0 => self.capture_carry_0(b), + 1 => self.capture_carry_1(b), + _ => unreachable!("carry_idx validated <= 1 at construction"), + }; + + let root = if self.cond_cols.is_empty() { + // Unconditional: carry * (1 - carry) + let one_minus_carry = b.sub(one, carry); + b.mul(carry, one_minus_carry) + } else { + // Conditional: cond * carry * (1 - carry), left-associated like + // `cond * &carry * (one - carry)`. + // cond = fold over cond_cols starting from zero: 0 + col0 + col1 + ... + let mut cond = b.const_base(0); + for &col in &self.cond_cols { + let c = b.main(0, col); + cond = b.add(cond, c); + } + let one_minus_carry = b.sub(one, carry); + let cond_carry = b.mul(cond, carry); + b.mul(cond_carry, one_minus_carry) + }; + + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // Helper Functions // ========================================================================= diff --git a/prover/src/tests/constraint_ir_tests.rs b/prover/src/tests/constraint_ir_tests.rs new file mode 100644 index 000000000..86bf51d81 --- /dev/null +++ b/prover/src/tests/constraint_ir_tests.rs @@ -0,0 +1,113 @@ +//! Differential tests for the explicit-builder constraint capture spike (Plan B). +//! +//! For each algebraic transition constraint, capture it into a flat IR via its +//! `Capture::capture` method (an explicit `IrBuilder`), then assert that +//! interpreting the IR reproduces the constraint's real +//! `evaluate::` bit-for-bit over many +//! random main rows. + +use crate::constraints::cpu::ProductZeroConstraint; +use crate::constraints::templates::{AddConstraint, AddOperand, IsBitConstraint}; +use crate::tables::types::{FE, GoldilocksExtension, GoldilocksField}; + +use math::field::element::FieldElement; +use stark::constraint_ir::{Capture, IrBuilder, eval_program_base}; +use stark::constraints::transition::TransitionConstraint; +use stark::table::TableView; + +/// Number of random trials per constraint. +const TRIALS: usize = 1000; + +/// Column count for the random frame; larger than any column index read by the +/// constraints under test (CPU columns go up to 37). +const NUM_COLS: usize = 64; + +/// A tiny deterministic SplitMix64 PRNG so the test needs no `rand` dependency +/// and is fully reproducible. +struct SplitMix64 { + state: u64, +} + +impl SplitMix64 { + fn new(seed: u64) -> Self { + Self { state: seed } + } + + fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) + } +} + +/// Run the differential check: capture `c` via the builder, then for `TRIALS` +/// random rows compare the real `evaluate` against the IR interpreter, +/// bit-for-bit. +fn assert_ir_matches_evaluate(c: &T, label: &str) +where + T: TransitionConstraint + Capture, +{ + let mut b = IrBuilder::new(); + c.capture(&mut b); + let prog = b.finish(); + eprintln!("[{label}] captured {} IR nodes", prog.len()); + + let mut rng = SplitMix64::new(0xDEAD_BEEF_CAFE_F00D ^ (label.len() as u64)); + + for trial in 0..TRIALS { + // Build a random main row. + let row: Vec = (0..NUM_COLS).map(|_| FE::from(rng.next_u64())).collect(); + + // Real evaluate: wrap the row in a base/ext TableView (1 row, no aux). + let real_step: TableView = + TableView::new(vec![row.clone()], vec![Vec::new()]); + let real: FieldElement = + c.evaluate::(&real_step); + + // IR interpreter over the same row. + let got = eval_program_base(&prog, &row); + + assert_eq!( + real, got, + "[{label}] mismatch at trial {trial}: real={real:?} got={got:?}" + ); + } +} + +#[test] +fn test_ir_matches_is_bit_unconditional() { + // X * (1 - X), X at column 7. + let c = IsBitConstraint::unconditional(7, 0); + assert_ir_matches_evaluate(&c, "is_bit_unconditional"); +} + +#[test] +fn test_ir_matches_is_bit_conditional() { + // cond * X * (1 - X), cond at column 3, X at column 5. + let c = IsBitConstraint::new(3, 5, 0); + assert_ir_matches_evaluate(&c, "is_bit_conditional"); +} + +#[test] +fn test_ir_matches_add_constraint_carries() { + // 64-bit ADD with embedded carries, DWordWL operands. + // cond at col 0; lhs=[1,2], rhs=[3,4], sum=[5,6]. + let (carry0, carry1) = AddConstraint::new_pair( + vec![0], + AddOperand::dword(1), + AddOperand::dword(3), + AddOperand::dword(5), + 0, + ); + assert_ir_matches_evaluate(&carry0, "add_carry_0"); + assert_ir_matches_evaluate(&carry1, "add_carry_1"); +} + +#[test] +fn test_ir_matches_product_zero() { + // col_a * col_b, columns 12 and 17. + let c = ProductZeroConstraint::new(12, 17, 0); + assert_ir_matches_evaluate(&c, "product_zero"); +} diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 4d0ac4477..d7c5824ea 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -15,6 +15,8 @@ pub mod commit_tests; #[cfg(test)] pub mod compute_commit_bus_offset_tests; #[cfg(test)] +pub mod constraint_ir_tests; +#[cfg(test)] pub mod constraints_tests; #[cfg(all(test, feature = "disk-spill"))] pub mod count_table_lengths_drift_tests; From 004230c65d1beec18246410192ca2d1ba7a1a90f Mon Sep 17 00:00:00 2001 From: MauroFab Date: Mon, 29 Jun 2026 17:42:38 -0300 Subject: [PATCH 2/6] docs(gpu-constraint-eval): add design docs + Plan-B roadmap to the production branch --- thoughts/gpu-constraint-eval/README.md | 96 ++ .../plan-builder-rewrite.md | 940 ++++++++++++++++++ .../plan-symbolic-field.md | 893 +++++++++++++++++ thoughts/gpu-constraint-eval/roadmap.md | 164 +++ 4 files changed, 2093 insertions(+) create mode 100644 thoughts/gpu-constraint-eval/README.md create mode 100644 thoughts/gpu-constraint-eval/plan-builder-rewrite.md create mode 100644 thoughts/gpu-constraint-eval/plan-symbolic-field.md create mode 100644 thoughts/gpu-constraint-eval/roadmap.md diff --git a/thoughts/gpu-constraint-eval/README.md b/thoughts/gpu-constraint-eval/README.md new file mode 100644 index 000000000..bfefac626 --- /dev/null +++ b/thoughts/gpu-constraint-eval/README.md @@ -0,0 +1,96 @@ +# GPU-ready constraint evaluation — design + +Design docs for moving STARK transition-constraint evaluation onto the GPU. + +## Why (the real motivation) + +Not to make constraint evaluation faster — it isn't the CPU bottleneck. The goal is +**data residency**: keep the whole prove pipeline on-device so we never round-trip the +LDE trace across the PCIe bus. + +``` +LDE (GPU) → constraint eval / composition poly → Merkle commit (GPU) → FRI (GPU) +``` + +The LDE trace (main + aux columns × blowup factor) is the largest array in the pipeline. +If constraint eval stays on the CPU, every proof must D2H-copy the entire LDE and push +results back — that transfer dominates. `gpu_lde.rs` already keeps columns resident +(`GpuLdeBase`/`GpuLdeExt3` keep-handles); on-GPU constraint eval consumes them in place. + +Consequence: the GPU kernel must also do the accumulation (`Σ αⁱ·Cᵢ·Zᵢ⁻¹`, Horner form ++ ÷Z) so it emits **composition-polynomial evaluations on-device**, not a raw `Cᵢ` +matrix (which would itself be a large D2H copy). + +## The blocker + +The prover is Rust; you can't run arbitrary Rust on a GPU. Constraints today are +`Vec>` evaluated via a generic +`evaluate` — two layers of dynamic dispatch, scalar, CPU-only. The logic has to +be re-expressed in a GPU-executable form. + +## The decided architecture + +**Capture each table's constraints once into a flat, single-field Goldilocks IR (a +typed `Dim1`/`Dim3` op-DAG), then interpret that IR** — on CPU (verifier/optional +prover) and on GPU (one universal Goldilocks kernel). Single source of truth → CPU and +GPU can't diverge. Not codegen, not a DSL, not hand-written per-table kernels. + +- Field: Goldilocks base (`Dim1`, `u64`) + degree-3 extension (`Dim3`, `[u64;3]`). +- IR ops: `Add/Sub/Mul/Neg` + leaves (`Main/Aux{offset,col}`, `Const`, `Periodic`, + `RapChallenge`, `AlphaPow`, `TableOffset`, `Shift`). +- Boundary: the zerofier/coefficient machinery stays in + `ConstraintEvaluator::evaluate_transitions`; the IR replaces only the per-row, + per-constraint step that produces each `Cᵢ` (on GPU, fused with the accumulation). + +This is the same family SP1, OpenVM, and zisk converged on. zisk is the closest match +(Goldilocks, FRI-STARK LDE quotient). + +## The two plans (the only open decision) + +Both produce the **same IR** and feed the **same interpreter + GPU kernel + validation**. +They differ *only* in how the IR is captured. + +| | [Plan A — symbolic field](./plan-symbolic-field.md) | [Plan B — builder rewrite](./plan-builder-rewrite.md) | +|---|---|---| +| Constraint edits | ~0 (record existing `evaluate` via a `SymField`) | ~600–800 LOC across 33 structs rewritten to `capture()` | +| Feasibility | HIGH — `SymField` needs only `IsField`+`IsSubFieldOf` (capture never builds an `AIR`); unreachable methods stubbed | No doubt; just labor | +| Risk shape | Concentrated in `SymField` — spike-able in 1–2 days | Spread across 33 transcriptions (Dvrm 11 / Cpu32 8 / Shift 7 kinds) | +| CPU path | can stay unchanged (IR GPU-only) | forced onto the interpreter (old `evaluate` deleted) | +| End state | recording field + arena + stubs retained | cleanest; generic `evaluate`+adapter deleted; ecosystem-idiomatic | +| Effort (CPU validated) | ~10–14 d | ~12–18 d | +| Effort (GPU) | ~6–10 d | ~5–7 d (identical, shared) | + +Both lose the per-row LogUp zero-skip (value-identical; recover via a static +const-fold peephole). Neither AVX nor monomorphization differentiates them (AVX lives +in the shared interpreter; monomorphization is a third thing neither plan does). + +## Reference implementation + +`others/openvm-stark-backend` (cloned `openvm-org/stark-backend@v1.4.0`) is a working +implementation of this exact approach for a FRI-STARK LDE quotient. Key files: + +- `crates/cuda-backend/src/transpiler/mod.rs` — lowers the symbolic DAG to three-address + code + liveness/linear-scan register allocation (the **IR processing** — the most + portable, field-agnostic piece; ~200 lines; the reg-alloc is optional for v1). +- `crates/cuda-backend/src/transpiler/codec.rs` + `cuda/include/codec.cuh` — encode rules + to a 128-bit packed word. +- `crates/cuda-backend/cuda/src/quotient.cu` — the interpreter kernel: per-row loop over + rules, fused Horner accumulation + ÷Z, per-thread intermediate buffer (local for small + programs, global spill for large — solves GPU scratch pressure). + +BabyBear→Goldilocks deltas to be aware of: the codec packs constants in 32 bits (fits +BabyBear's 31-bit modulus, **not** Goldilocks' 64 bits) → needs a side constant table; +extension is BabyBear's vs our degree-3; OpenVM evaluates everything in `FpExt` (no +base/ext split). It's a blueprint to port, not a crate to depend on (it's tied to +OpenVM's symbolic-DAG type, `PrimeField32` bound, and trace/bus conventions). + +SP1's `sp1-gpu` is the same pattern via an SSA register-machine bytecode (~60+ opcodes, +operand types in the opcode); OpenVM puts operand types in the source tag (~6 ops) — +the latter is the better template for single-field Goldilocks (even fewer source types). + +## Recommendation / next step + +Spike **Plan A** first (1–2 days): implement `SymField`, capture the CPU table, diff the +interpreted IR bit-for-bit against the current evaluator, and dump per-table node counts. +The IR/interpreter/GPU kernel are shared, so switching to Plan B later costs almost +nothing. If `SymField` fights the trait tower, fall back to the Plan B rewrite. diff --git a/thoughts/gpu-constraint-eval/plan-builder-rewrite.md b/thoughts/gpu-constraint-eval/plan-builder-rewrite.md new file mode 100644 index 000000000..e6e3d2bd2 --- /dev/null +++ b/thoughts/gpu-constraint-eval/plan-builder-rewrite.md @@ -0,0 +1,940 @@ +# Plan: Capture-Method Rewrite ("Change the constraints") for GPU-ready STARK constraint evaluation + +> Approach: rewrite each transition constraint so it *emits* its polynomial into a +> builder/capture abstraction once at setup, producing a flat single-field +> Goldilocks IR, then interpret that IR on CPU (prover over the LDE coset; verifier +> at the OOD point) and later on GPU. This is the head-to-head sibling of the +> "wrap the field type / shadow `IsField`" approach. + +All file/line references below were read and verified against the working tree +(branch `main`) unless explicitly marked `? INFERRED` or `✗ UNCERTAIN`. + +--- + +## 1. Overview & end-state + +After this change, every table's transition constraints are *captured once* into a +flat per-table IR program (`TableProgram`) at AIR-construction time. The +per-row/per-OOD hot path no longer dispatches through +`Vec>` calling a generic `evaluate`; +instead `air.compute_transition_prover` (prover) and `air.compute_transition` +(verifier) call a single **interpreter** that walks the IR against the current +`Frame`/`TableView`, writing each constraint's scalar `Cᵢ` into the existing +`base_evals`/`ext_evals` buffers. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary`, +zerofiers, and `ZerofierEvaluations` machinery in +`ConstraintEvaluator::evaluate_transitions` are **untouched** — the IR/interpreter +only replaces the step that produces each `Cᵢ`. Because the IR is a flat array of +Goldilocks-typed ops, the same bytes feed a single universal Goldilocks +interpreter CUDA kernel, dispatched through the existing TypeId+transmute seam used +by `gpu_lde.rs`. + +``` + ┌─────────────── setup (once per AIR) ──────────────┐ + constraint structs ──► capture(&mut IrBuilder) ──► TableProgram (flat IR) │ + (IsBit, Add, Mul…) (column reads/+/-/* → IR nodes) { ops, consts, │ + emits, n_dim1 } │ + └───────────────────────────────────────────────────┘ + │ stored in the AIR + ┌───────────────────────────┼───────────────────────────────┐ + ▼ prover, per LDE row ▼ verifier, at OOD z ▼ GPU + interpret(program, frame_prover) interpret(program, frame_verifier) cuda kernel(program, lde) + → base_evals[], ext_evals[] → ext_evals[] → Cᵢ per row, per table + │ │ │ + └────── feeds unchanged ───────────┴── ConstraintEvaluator ─────┘ + Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary (evaluator.rs:102-134, untouched) +``` + +--- + +## 2. The IR (concrete Rust data structures) + +The IR is **single-field over Goldilocks** with explicit base (`dim1`) vs cubic-ext +(`dim3`) typing on each node, so the interpreter knows the storage width and which +arithmetic routine to use. It lives in a new module `crypto/stark/src/ir.rs`. + +### 2.1 Node typing + +`Dim` records the field width of a value. Goldilocks base = `Dim1` (`[u64;1]`), +`Degree3GoldilocksExtensionField` = `Dim3` (`[u64;3]`). (Verified: `IsField for +Degree3GoldilocksExtensionField { type BaseType = [FpE;3] }`, +`extensions_goldilocks.rs:277`; base = `repr(transparent)` `u64`.) + +```rust +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Dim { D1, D3 } + +/// Index into the program's node arena. Nodes are in topological (emission) order, +/// so an interpreter can evaluate left-to-right into a value stack/arena. +pub type NodeId = u32; +``` + +### 2.2 Op / node enum + +```rust +#[derive(Clone, Copy, Debug)] +pub enum Op { + // ---- leaves (inputs) ---- + /// main trace column read: step offset (0 = this row, 1 = next row), column idx. + Main { offset: u8, col: u16 }, // always Dim1 on prover; Dim3 on verifier (see §6) + /// aux trace column read. + Aux { offset: u8, col: u16 }, // Dim3 (aux is extension-valued) + /// constant base-field element, index into `consts_d1`. + ConstD1 { k: u32 }, // Dim1 + /// constant ext element, index into `consts_d3`. + ConstD3 { k: u32 }, // Dim3 (rarely needed for algebraic; see §5) + /// periodic column j at this row (uniform per-row input). + Periodic { j: u16 }, // Dim1 + /// LogUp challenge i (z=0, alpha=1 by convention), uniform per-proof. + Challenge { i: u16 }, // Dim3 + /// alpha power k (precomputed Σ over the proof), uniform per-proof. + AlphaPow { k: u16 }, // Dim3 + /// logup_table_offset uniform per-proof. + TableOffset, // Dim3 + /// packing shift constant (8,16,24) — small base consts, can also be ConstD1. + // (shifts are just ConstD1 entries; no dedicated op needed.) + + // ---- arithmetic (operands are NodeIds already emitted) ---- + Add { a: NodeId, b: NodeId }, + Sub { a: NodeId, b: NodeId }, + Mul { a: NodeId, b: NodeId }, + Neg { a: NodeId }, +} + +#[derive(Clone, Copy, Debug)] +pub struct Node { pub op: Op, pub dim: Dim } +``` + +The interpreter's typing rule for `Add/Sub/Mul`: `dim = max(dim(a), dim(b))` +(D3 > D1). A `Mul` of `D1×D3` is the cheap subfield mul (componentwise scalar, +`GoldilocksField: IsSubFieldOf::mul` = `[a*b0, a*b1, a*b2]`, verified +`extensions_goldilocks.rs:413`); `D3×D3` is the full cubic mul +`c0=a0b0+2(a1b2+a2b1)`, … (verified `extensions_goldilocks.rs:298-306`); `D1×D1` +is a plain `GoldilocksField::mul`. This single rule subsumes every "mixed base×ext" +case the current code handles via `IsSubFieldOf`. + +### 2.3 Per-table program + +```rust +pub struct TableProgram { + pub nodes: Vec, // topological arena + pub consts_d1: Vec, // deduplicated base constants + pub consts_d3: Vec<[u64; 3]>, // deduplicated ext constants (usually empty) + /// emit[c] = NodeId of constraint c's root value. Length = num_transition_constraints. + pub emits: Vec, + /// First `num_base` constraints are D1 (base-field), matching + /// `num_base_transition_constraints()`. Used to split base/ext eval buffers. + pub num_base: usize, + /// metadata for input plumbing / GPU upload + pub num_main_cols: u16, + pub num_aux_cols: u16, + pub max_offset: u8, // 1 (next-row) for LogUp accumulated; else 0 +} +``` + +### 2.4 Serialization for GPU + +`nodes` is `Vec`; `Node`/`Op` are `Copy` plain-old-data. For GPU we lower +`Op` to a fixed-width tagged record `struct GpuOp { tag: u32, dim: u32, a: u32, b: u32 }` +(16 bytes) — leaves pack their immediates into `a`/`b` (e.g. `Main`: a=offset, +b=col). `consts_d1`/`consts_d3` upload as `&[u64]`. This is a flat `Vec`/`Vec` +blob: H2D once per table at setup, reused for every LDE row (the kernel runs the +program per row). No per-row host work crosses the boundary — only the device- +resident LDE columns (already kept on device by R1, see `gpu_lde.rs` `gpu_main()`/ +`gpu_aux()` handles) plus the uniforms (challenges, alpha powers, periodic, table +offset). `Op`'s representation is internal; we do **not** need `serde` on it unless +we want to cache programs to disk (out of scope). + +--- + +## 3. Capture front-end — builder/capture API & object-safety (distinguishing section) + +### 3.1 Object-safety decision (Question 1) — **RECOMMENDATION: non-generic `capture(&self, &mut IrBuilder)`** + +The constraints are stored heterogeneously as +`Vec>>` +(verified `traits.rs:316`, `lookup.rs:813`). A method generic over a builder type +`fn capture(&self, &mut AB)` is **NOT object-safe** (generic methods +can't go through a vtable), so it could not be called on `Box`. Two ways out: + +- **(a) non-generic `capture(&self, builder: &mut IrBuilder)` with a CONCRETE + builder.** Object-safe. Runs **once at setup** (not in the hot loop), so the + concrete builder costs nothing at steady state. The builder is a struct, not a + trait. This is the minimal, lowest-risk change: the existing + `TransitionConstraint` trait gains one object-safe method. +- (b) builder-generic `eval` (Plonky3/SP1 `AirBuilder` style). To call it + through a boxed trait object you must either (i) monomorphize per concrete AIR + (de-box: store constraints as a concrete `enum`/typed vec per table, a much + bigger refactor touching every table's assembly fn and `AirWithBuses`), or + (ii) add a non-generic shim per constraint anyway (which is just (a) again). + +**Recommendation: (a).** Reasoning: +1. It is object-safe, so it drops straight into the existing + `Box` storage with zero changes to how tables assemble constraints + (`create_all_cpu_constraints`, `mul_constraints`, `dvrm_constraints`, …). +2. Capture is a one-time setup cost; there is no monomorphization win to be had at + runtime because the runtime work is the interpreter, not the constraint body. +3. The interpreter is the single hot path; we want exactly one concrete builder so + the IR is canonical and identical for CPU and GPU. A generic builder would let + callers instantiate it with an "eval-directly" builder, re-introducing the very + `IsField` trait-tower fight this approach exists to avoid. + +The one real cost of (a): the builder is monomorphic on Goldilocks, so a +constraint can't be captured for a non-Goldilocks field. That is exactly the +project's constraint (base = Goldilocks, ext = degree-3 Goldilocks), so it's a +non-issue here. The generic `evaluate` is retained transitionally for +migration/validation (see §6, §9) and deleted at the end. + +### 3.2 The `IrBuilder` surface (Question 2) + +```rust +pub struct IrBuilder { + nodes: Vec, + consts_d1: Vec, + consts_d3: Vec<[u64; 3]>, + emits: Vec, // indexed by constraint_idx + const_d1_cache: HashMap, // dedupe constants + const_d3_cache: HashMap<[u64;3], u32>, + num_main_cols: u16, + num_aux_cols: u16, + max_offset: u8, + // CSE cache: (Op canonicalized) -> NodeId, to coalesce repeated subexpressions + // (e.g. `one`, `1 - x`, shift consts). Optional but cheap and shrinks the IR a lot. + cse: HashMap, +} + +/// Typed handle so `+ - *` compose with compile-time dim tracking and a tiny op set. +#[derive(Clone, Copy)] +pub struct Expr { id: NodeId, dim: Dim } + +impl IrBuilder { + // ---- leaves ---- + pub fn main(&mut self, offset: u8, col: usize) -> Expr; // Dim1 + pub fn aux(&mut self, offset: u8, col: usize) -> Expr; // Dim3 + pub fn const_base(&mut self, v: u64) -> Expr; // Dim1 (dedup) + pub fn const_signed(&mut self, v: i64) -> Expr; // Dim1, maps i64→field + pub fn const_ext(&mut self, v: [u64;3]) -> Expr; // Dim3 (dedup) + pub fn one(&mut self) -> Expr; // = const_base(1) + pub fn periodic(&mut self, j: usize) -> Expr; // Dim1 + pub fn challenge(&mut self, i: usize) -> Expr; // Dim3 + pub fn alpha_power(&mut self, k: usize) -> Expr; // Dim3 + pub fn table_offset(&mut self) -> Expr; // Dim3 + pub fn bus_id(&mut self, id: u64) -> Expr; // = const_base(id) (α⁰ term) + + // ---- arithmetic (auto dim = max) ---- + pub fn add(&mut self, a: Expr, b: Expr) -> Expr; + pub fn sub(&mut self, a: Expr, b: Expr) -> Expr; + pub fn mul(&mut self, a: Expr, b: Expr) -> Expr; + pub fn neg(&mut self, a: Expr) -> Expr; + + // ---- output ---- + /// Record that constraint `constraint_idx` evaluates to `e`. + pub fn emit(&mut self, constraint_idx: usize, e: Expr); + + pub fn finish(self) -> TableProgram; +} +``` + +Notes on the surface vs the prompt's sketch: +- No `table_offset()` for periodic *exemption offsets* — those stay in the + zerofier machinery (`transition.rs:160`), which is outside the boundary. +- `Expr` carries `dim`, so `mul(d1, d3)` is legal and lowers to the cheap subfield + mul; `Expr` makes the constraint bodies read almost identically to today. +- CSE + constant dedup are pure size optimizations; correctness doesn't depend on + them. (`one`, `shift_16`, `INV_SHIFT_32` recur across most bodies.) + +### 3.3 Trait change + +`TransitionConstraint` (`transition.rs:332`) gains: + +```rust +/// Emit this constraint's polynomial into the builder. Called once at setup. +/// `builder.emit(self.constraint_idx(), root)` records the result. +fn capture(&self, builder: &mut IrBuilder); +``` + +`TransitionConstraintEvaluator` (`transition.rs:10`, object-safe) gains a forwarding +non-generic method: + +```rust +fn capture(&self, builder: &mut IrBuilder); +``` + +The adapter `TransitionConstraintAdapter` (`transition.rs:395`) forwards +`capture` to `self.0.capture(builder)`. During migration the adapter keeps its +`evaluate_verifier`/`evaluate_prover` too (used by the parallel old path for +bit-for-bit validation, §12). + +--- + +## 4. Rewriting the algebraic constraints (Question 3 + full scope) + +### 4.1 Before/after: `IsBitConstraint` (`templates.rs:80-108`) + +**Before** (`evaluate`): +```rust +let x = step.get_main_evaluation_element(0, self.value_col).clone(); +let one = FieldElement::::one(); +match self.cond_col { + Some(cond_col) => { let cond = step.get_main_evaluation_element(0, cond_col).clone(); + &cond * &x * (one - x) } + None => &x * (one - &x), +} +``` +**After** (`capture`): +```rust +fn capture(&self, b: &mut IrBuilder) { + let x = b.main(0, self.value_col); + let one = b.one(); + let omx = b.sub(one, x); + let root = match self.cond_col { + Some(c) => { let cond = b.main(0, c); let xm = b.mul(x, omx); b.mul(cond, xm) } + None => b.mul(x, omx), + }; + b.emit(self.constraint_idx, root); +} +``` + +### 4.2 Before/after: `AddConstraint` — `AddOperand`/`AddLinearTerm` mapping (`templates.rs:359-486`) + +The lo/hi-limb abstraction with i64 coefficients maps cleanly. `AddLinearTerm::eval` +(`templates.rs:164`) becomes `capture`: +```rust +impl AddLinearTerm { + fn capture(&self, b: &mut IrBuilder) -> Expr { + match self { + AddLinearTerm::Column { coefficient, column } => { + let col = b.main(0, *column); + let k = b.const_signed(*coefficient); // i64 → field, was FieldElement::from(*coefficient) + b.mul(col, k) + } + AddLinearTerm::Constant(v) => b.const_signed(*v), + } + } +} +fn eval_terms_capture(terms: &[AddLinearTerm], b: &mut IrBuilder) -> Expr { + // empty → zero + let mut acc = b.const_base(0); + for t in terms { let e = t.capture(b); acc = b.add(acc, e); } + acc +} +``` +`AddOperand::eval_lo/eval_hi` → `capture_lo/capture_hi` (DWordWL reads +`main(0,start)` / `main(0,start+1)`; Linear → `eval_terms_capture`). Then +`compute_carry_0` (`templates.rs:414`): +```rust +// carry_0 = (lhs_lo + rhs_lo - sum_lo) * 2^(-32) +let inv = b.const_base(INV_SHIFT_32); // templates.rs:30, precomputed 2^-32 +let s = b.sub(b.add(lhs_lo, rhs_lo), sum_lo); +let c0 = b.mul(s, inv); +``` +`compute_carry_1` adds `carry_0` then the same `*inv`. `compute` then folds the cond +columns (`fold(zero, +)` → chain of `add`) and emits +`cond * carry * (one - carry)` (or unconditional). **The i64 coefficients are the +only subtlety** and they vanish because `const_signed(i64)` reproduces +`FieldElement::::from(i64)` exactly (the field's `From` already canonicalizes +negatives mod p). The lo/hi limb logic is pure compile-time structure; the captured +IR is a flat add/mul chain identical in value to the current `evaluate`. + +### 4.3 Before/after: `ProductZeroConstraint` (`cpu.rs:96-113`) + +**Before:** `step.get_main(0,col_a) * step.get_main(0,col_b)`. +**After:** +```rust +fn capture(&self, b: &mut IrBuilder) { + let a = b.main(0, self.col_a); let c = b.main(0, self.col_b); + let r = b.mul(a, c); b.emit(self.constraint_idx, r); +} +``` + +### 4.4 Before/after: a more complex algebraic constraint — `MulConstraint::RawProduct` (`mul.rs:766-844`) + +This is the representative "mega-constraint": a `kind` enum dispatched in +`compute()` (`mul.rs:721`), with a convolution body whose `for k` / `for j` loops +are bounded by compile-time `i` (not data). Capturing it **runs those loops once**, +unrolling them into a flat IR chain: + +```rust +// raw_product[i] - Σ_k 2^(16k) Σ_j lhs_ext[j]·rhs_ext[idx-j] +fn capture_raw_product(&self, i: usize, b: &mut IrBuilder) -> Expr { + let lhs = [cols::LHS_0, cols::LHS_1, cols::LHS_2, cols::LHS_3].map(|c| b.main(0, c)); + let rhs = [cols::RHS_0, cols::RHS_1, cols::RHS_2, cols::RHS_3].map(|c| b.main(0, c)); + let ln = b.main(0, cols::LHS_IS_NEGATIVE); + let rn = b.main(0, cols::RHS_IS_NEGATIVE); + let sf = b.const_base(SIGN_FILL); + let mut lhs_ext = [b.const_base(0); 8]; + let mut rhs_ext = [b.const_base(0); 8]; + lhs_ext[..4].copy_from_slice(&lhs); rhs_ext[..4].copy_from_slice(&rhs); + for j in 4..8 { lhs_ext[j] = b.mul(sf, ln); rhs_ext[j] = b.mul(sf, rn); } + let shift_16 = b.const_base(SHIFT_16); + let mut sum = b.const_base(0); + for k in 0..=1usize { + let idx = 2*i + k; + if idx < 8 { + let mut inner = b.const_base(0); + for j in 0..=idx { if j < 8 && idx-j < 8 { + inner = b.add(inner, b.mul(lhs_ext[j], rhs_ext[idx-j])); } } + sum = if k==0 { b.add(sum, inner) } else { b.add(sum, b.mul(inner, shift_16)) }; + } + } + let raw = b.main(0, raw_col_for(i)); + b.sub(raw, sum) +} +``` +**This is the central churn-reducing insight: no algebraic body has data-dependent +control flow.** Every loop bound, conditional, and column index is a function of +`self` only. So `capture` is a *mechanical mirror* of the existing body: swap +`FieldElement` constructors for builder leaves and `+ - *` for `b.add/sub/mul`. The +`kind`-enum dispatch in `compute` becomes a `kind`-enum dispatch in `capture`. + +### 4.5 Full rewrite scope (counts verified by grep + reads) + +`grep -rn "impl TransitionConstraint"` across `prover/src/` yields the following +**distinct constraint structs** (each implements the user trait once; structs with a +`kind` enum produce many constraint *instances* but are ONE body to rewrite): + +**`prover/src/constraints/` (11 structs):** +- `templates.rs`: `IsBitConstraint`, `AddConstraint` (+ `AddOperand`/`AddLinearTerm` + helper enums — these get `capture` helpers, not trait impls). +- `cpu.rs`: `ProductZeroConstraint`, `Arg2ExclusiveConstraint`, `MemFlagsBitConstraint`, + `RegNotReadIsZeroConstraint`, `Arg2Constraint`, `RvdEqResConstraint`, + `BranchRvdConstraint`, `BranchCondConstraint`, `NextPcAddConstraint` (+ helper + `res_word`). All small (≤ ~30-line bodies). + +**`prover/src/tables/` (per verified grep, 21 impl sites across 17 files; some files +hold several structs):** +- `mul.rs` `MulConstraint` (kind enum, ~250 lines incl. helpers; convolution). +- `dvrm.rs` `DvrmConstraint` (kind enum, 11 variants; the biggest, ~1300-line file — + body+helpers the largest single rewrite). +- `shift.rs` `ShiftConstraint` (kind enum; ~1100-line file). +- `cpu32.rs` `Cpu32Constraint` (kind enum; ~845-line file). +- `memw.rs` `MemwConstraint`; `memw_aligned.rs` `MemwAlignedConstraint`; + `memw_register.rs` `MemwRegisterMuSumIsBit`. +- `load.rs` `LoadConstraint`; `store.rs` `StoreConstraint`. +- `lt.rs` `LtConstraint`; `eq.rs` `EqXorConstraint`. +- `branch.rs` `BranchConstraint`; `commit.rs` `CommitConstraint`. +- `keccak.rs` `KeccakAddressNoOverflowConstraint` (one small struct). NOTE: keccak's + 51 constraints are mostly **reused `AddConstraint` instances** (`from_dword_bl` + + `constant` + `from_dword_hl`, verified `keccak.rs:545-557`) — so keccak adds almost + no rewrite cost once `AddConstraint::capture` exists, and its program is small (no + GPU register-pressure risk). +- `ec_scalar.rs` `MulZeroConstraint`. +- `ecsm.rs`: `ConvCarry`, `ColIsZero`, `CarryBit`, `OverflowRequired` (4 structs). +- `ecdas.rs`: `ConvCarry`, `ColIsZero`, `MulZero` (3 structs). + +**Authoritative count (enumeration-verified): 33 algebraic +`impl TransitionConstraint` structs across 19 files + 2 framework LogUp +`TransitionConstraintEvaluator` structs (§5).** Breakdown: +- `prover/src/constraints/cpu.rs` (9): ProductZero, Arg2Exclusive, MemFlagsBit, + RegNotReadIsZero, Arg2, RvdEqRes, BranchRvd, BranchCond, NextPcAdd. +- `prover/src/constraints/templates.rs` (2): IsBit, Add (Add carries the + AddOperand/AddLinearTerm combinators — the trickiest single rewrite, §4.2). +- `prover/src/tables/` (22): Branch, Commit, Cpu32, Dvrm, EqXor, MulZero(ec_scalar), + ConvCarry+ColIsZero+MulZero(ecdas, 3), ConvCarry+ColIsZero+CarryBit+OverflowRequired + (ecsm, 4), KeccakAddressNoOverflow, Load, Lt, Memw, MemwAligned, MemwRegisterMuSumIsBit, + Mul, Shift, Store. + +**Scope driver — multi-kind dispatch structs** (one struct, a `kind` enum + a +`compute()` helper; each kind is a separate constraint *instance* needing its own +capture path). Verified kind counts: Dvrm(11), Cpu32(8), Shift(7), Lt(6), Load(6), +Mul(6), Branch(5), Memw(3), MemwAligned(3), Store(2). Dvrm/Cpu32/Shift dominate. +**Rough total evaluate/compute body LOC ≈ 600-800 across the 19 files** — far less +than the raw file sizes suggest, because the kind-enum bodies are short matches that +delegate to `compute()`, and the heavy loops (carry chains, raw-product convolution, +shift formulas) are **statically bounded / metadata-driven**, so they unroll into +builder calls at capture time without per-kind hand-coding of each iteration. +(I read `mul.rs`/`dvrm.rs` bodies in full; the rest share the single-`evaluate` +→`compute()` pattern, kind counts enumeration-verified.) + +--- + +## 5. Rewriting the LogUp / extension framework constraints (Question 4 — the crux) + +These two live in `crypto/stark/src/lookup.rs` and are the only constraints that use +extension arithmetic, challenges, alpha powers, and (for the accumulated one) +next-row reads. They are **not** `TransitionConstraint` impls — they directly +implement the object-safe `TransitionConstraintEvaluator` (`lookup.rs:1741`, +`lookup.rs:1868`). So for these we write `capture` directly on the evaluator impl. +I read both bodies in full; here is how each maps. + +### 5.1 Fingerprint, multiplicity, sign — the shared pieces + +`compute_fingerprint_from_step` (`lookup.rs:1689-1709`) builds +`z − (bus_id + Σ α^k · vₖ)` where `vₖ` are the packed bus elements. In IR: + +```rust +// fingerprint(interaction) -> Expr (Dim3, because z and alpha powers are Dim3) +fn capture_fingerprint(b: &mut IrBuilder, bi: &BusInteraction) -> Expr { + let z = b.challenge(0); // rap_challenges[0] + // α⁰ term: bus_id is a base const, added directly (matches lookup.rs:1697) + let mut lc = b.bus_id(bi.bus_id); // Dim1 const, promoted on first add to Dim3 + let mut alpha_idx = 1usize; // α⁰ handled, start at α¹ (lookup.rs:1698) + for bv in &bi.values { + alpha_idx += capture_busvalue_fingerprint(b, bv, alpha_idx, &mut lc); + } + b.sub(z, lc) // z - lc (Dim3) +} +``` + +`BusValue::accumulate_fingerprint_from_step` (`lookup.rs:738-796`) and +`Packing::accumulate_fingerprint_with` (`lookup.rs:272-369`) are the packing +formulas. They are pure compile-time structure (the `match self { Packing::Word2L => +h0 + 2^16·h1, … }`), so capturing them unrolls the same way as §4.4: + +```rust +fn capture_busvalue_fingerprint(b: &mut IrBuilder, bv: &BusValue, + alpha_off: usize, lc: &mut Expr) -> usize { + match bv { + BusValue::Packed { start_column, packing } => { + // mirror accumulate_fingerprint_with: e.g. Word2L: + // combined = col[start] + col[start+1]·shift_16 (Dim1) + // *lc += combined · alpha_powers[alpha_off] (Dim1 · Dim3 -> Dim3) + let elems = capture_packing(b, *packing, *start_column); // Vec (Dim1) + for (i, e) in elems.iter().enumerate() { + let ap = b.alpha_power(alpha_off + i); // Dim3 + let t = b.mul(*e, ap); // Dim1·Dim3 -> Dim3 + *lc = b.add(*lc, t); + } + packing.num_bus_elements() + } + BusValue::Linear(terms) => { + // result = Σ coeff·col + const (Dim1), then *lc += result·α^alpha_off + let mut r = b.const_base(0); + for t in terms { match t { + LinearTerm::Column{coefficient, column} + | LinearTerm::ColumnUnsigned{coefficient, column}=> { + let col=b.main(0,*column); let k=b.const_signed(*coefficient as i64); + r=b.add(r, b.mul(col,k)); } + LinearTerm::Constant(v)=> r=b.add(r, b.const_signed(*v)), + }} + let ap=b.alpha_power(alpha_off); *lc=b.add(*lc, b.mul(r, ap)); + 1 + } + } +} +``` +> **Honesty note on the runtime zero-skip:** the current code skips the +> `result · α` multiply when `result == 0` *on that row* (`lookup.rs:675-677`, +> `790-792`). That is a *data-dependent* optimization the IR **cannot** reproduce — +> the IR is row-agnostic. The IR always emits the multiply. This is the one place +> the capture approach is strictly less optimal than the current per-row code: a +> few extra D1×D3 muls per row for bus elements that happen to be zero. It does +> **not** change the result (adding `0·α` is a no-op), only cost. Quantify in +> validation; likely negligible vs. the dispatch savings. (✓ VERIFIED the skip +> exists and is value-preserving.) + +`compute_multiplicity_from_step` (`lookup.rs:1679-1684`) = `Multiplicity::evaluate_with` +(`lookup.rs:1252-1282`): `One→1`, `Column→col`, `Sum→a+b`, `Negated→1-col`, +`Diff→a-b`, `Sum3→a+b+c`, `Linear→Σ`. All Dim1, captured as add/sub/mul chains. + +**The sign** (`is_sender`) is a **compile-time bool on the interaction**, so it is +resolved during capture by choosing `add` vs `sub` (or wrapping in `neg`) — never an +IR value. This matches the current "conditional negation instead of E×E sign +multiplication" (`lookup.rs:1779-1790`). + +### 5.2 `LookupBatchedTermConstraint::capture` (was `lookup.rs:1754-1831`) + +Formula (verified `lookup.rs:1791`): +`c·fp_a·fp_b − sign_a·m_a·fp_b − sign_b·m_b·fp_a`. + +```rust +fn capture(&self, b: &mut IrBuilder) { + let c = b.aux(0, self.term_column_idx); // Dim3 + let fp_a = capture_fingerprint(b, &self.interaction_a); + let fp_b = capture_fingerprint(b, &self.interaction_b); + let m_a = capture_multiplicity(b, &self.interaction_a.multiplicity); // Dim1 + let m_b = capture_multiplicity(b, &self.interaction_b.multiplicity); + let term_a = b.mul(m_a, fp_b); // Dim1·Dim3 -> Dim3 + let term_a = if self.interaction_a.is_sender { term_a } else { b.neg(term_a) }; + let term_b = b.mul(m_b, fp_a); + let term_b = if self.interaction_b.is_sender { term_b } else { b.neg(term_b) }; + let main = b.mul(b.mul(c, fp_a), fp_b); + let root = b.sub(b.sub(main, term_a), term_b); + b.emit(self.constraint_idx, root); +} +``` +Clean. Degree 3, all Dim3 at the top, exactly mirrors the read body. + +### 5.3 `LookupAccumulatedConstraint::capture` (was `lookup.rs:1881-2005`) — the messy one + +This is the only constraint that reads **two row offsets** (`acc_curr` at offset 0, +`acc_next` and the term columns at offset 1) — verified +`first_step.get_aux(0, acc)` / `second_step.get_aux(0, …)` where `first_step = +frame.get_evaluation_step(0)` and `second_step = frame.get_evaluation_step(1)` +(`lookup.rs:1971-1972`, `1899-1905`). The IR addresses next-row values with +`b.aux(1, col)` — this is exactly why `Op::Main/Aux` carry an `offset: u8` and why +the program records `max_offset` (the interpreter must fill a 2-step frame for these +tables; the prover already builds frames with `offsets = [0,1]`, see +`AirWithBuses` context `transition_offsets: vec![0,1]`, `lookup.rs:909`). + +```rust +fn capture(&self, b: &mut IrBuilder) { + let acc_curr = b.aux(0, self.acc_column_idx); // offset 0 + let acc_next = b.aux(1, self.acc_column_idx); // offset 1 <-- next row + // terms_sum over committed term columns at offset 1 (lookup.rs:1903) + let mut terms = b.const_base(0); + for i in 0..self.num_term_columns { terms = b.add(terms, b.aux(1, i)); } + // delta = acc_next - acc_curr - terms_sum + L/N + let off = b.table_offset(); // logup_table_offset (Dim3) + let delta = b.add(b.sub(b.sub(acc_next, acc_curr), terms), off); + match self.absorbed.len() { + 1 => { // delta·f - sign·m (lookup.rs:1932) + let f = capture_fingerprint_at(b, &self.absorbed[0], /*offset*/1); + let m = capture_multiplicity_at(b, &self.absorbed[0].multiplicity, 1); + let mt = if self.absorbed[0].is_sender { m } else { b.neg(m) }; + let root = b.sub(b.mul(delta, f), mt); + b.emit(self.constraint_idx, root); + } + 2 => { // delta·f1·f2 - sign1·m1·f2 - sign2·m2·f1 (lookup.rs:1957) + let f1=capture_fingerprint_at(b,&self.absorbed[0],1); + let f2=capture_fingerprint_at(b,&self.absorbed[1],1); + let m1=capture_multiplicity_at(b,&self.absorbed[0].multiplicity,1); + let m2=capture_multiplicity_at(b,&self.absorbed[1].multiplicity,1); + let t1=b.mul(m1,f2); let t1=if self.absorbed[0].is_sender{t1}else{b.neg(t1)}; + let t2=b.mul(m2,f1); let t2=if self.absorbed[1].is_sender{t2}else{b.neg(t2)}; + let root=b.sub(b.sub(b.mul(b.mul(delta,f1),f2),t1),t2); + b.emit(self.constraint_idx, root); + } + _ => unreachable!(), + } +} +``` +> **The messiness, stated honestly:** +> 1. `capture_fingerprint`/`capture_multiplicity` need an **offset parameter** because +> the absorbed interactions read columns at the *next* row (`second_step`, +> `lookup.rs:1919-1946`), whereas the batched-term constraint reads the *current* +> row. The fingerprint/packing capture helpers (§5.1) must thread `offset: u8` +> through to every `b.main(offset, …)`/`b.aux(offset, …)`. This is a real but +> mechanical generalization (one extra arg). +> 2. The `1` vs `2` absorbed cases have different degree (2 vs 3) and different +> formulas; both must be captured (matches the existing `match absorbed.len()`). +> 3. `logup_table_offset` becomes the `TableOffset` uniform leaf (§8). It is `L/N`, +> a single Dim3 value computed in `ConstraintEvaluator::new` (`evaluator.rs:199`) +> and passed via the context — already a per-proof uniform. +> +> **Verdict:** LogUp maps to the builder *cleanly but with one wart* — the per-row +> zero-skip (§5.1) is lost, and the fingerprint helpers must be offset-parameterized. +> Neither blocks the approach; both are mechanical. This is materially **less messy** +> than fighting `IsField` to make a shadow-field type carry the same z/α/alpha-power +> uniforms through `compute_fingerprint_from_step`'s generic `>` +> signature (the sibling approach's burden). The deciding factor leans toward this +> approach because the capture is a near-verbatim transcription of the existing, +> already-factored helpers. + +--- + +## 6. CPU interpreter & the boundary (Question 5) + +### 6.1 Where it slots in + +The boundary is exactly `air.compute_transition_prover` (prover, `traits.rs:254`) +and `air.compute_transition` (verifier, `traits.rs:223`). Today both loop over +`transition_constraints()` calling `evaluate_prover`/`evaluate_verifier`. After the +rewrite, `AirWithBuses` (the only production AIR, `lookup.rs:964`) overrides both to +call the interpreter against its stored `TableProgram`: + +```rust +fn compute_transition_prover(&self, ctx, base_evals, ext_evals) { + interpret_prover(&self.program, ctx, base_evals, ext_evals); +} +fn compute_transition(&self, ctx) -> Vec> { + let mut ext = vec![FieldElement::zero(); self.num_transition_constraints()]; + interpret_verifier(&self.program, ctx, &mut ext); + ext +} +``` + +`ConstraintEvaluator::evaluate_transitions` (`evaluator.rs:79-135`) is **unchanged**: +it still calls `air.compute_transition_prover(&ctx, base_buf, transition_buf)` +(`evaluator.rs:100`) and accumulates with zerofiers (`evaluator.rs:102-134`). The IR +sits entirely inside the AIR's override. + +### 6.2 Base vs ext handling — two interpreters, shared walk + +- **Prover** frame is `Frame`: `main` reads are **Dim1** + (base), `aux` reads are **Dim3**. So `interpret_prover` evaluates each node into + either a `u64` (D1) or `[u64;3]` (D3) slot. The first `num_base` constraints are + D1-rooted and written into `base_evals: &mut [FieldElement]`; the rest are + D3-rooted into `ext_evals[num_base..]`. This reproduces the existing F×E split + (`evaluator.rs:104-114`, `transition.rs:439-458`). Verified: base constraints + must be the first `num_base_transition_constraints()` and the LogUp constraints + are appended last (`lookup.rs:857`, `traits.rs:244`). +- **Verifier** frame is `Frame`: there is no base field; every value + is Dim3 (the verifier "works with a frame that contains only elements from the + extension", `traits.rs:69-71`). So `interpret_verifier` runs the *same* node walk + but treats `Main` reads as Dim3 (the column value is already an ext element) and + every op as Dim3. The IR's per-node `dim` is the prover's typing; the verifier + simply promotes D1 leaves to D3. One IR, two interpreters differing only in leaf + loading and whether D1 storage is used. + +Implementation: a value arena `Vec` where `enum Val { D1(u64), D3([u64;3]) }`, +or two parallel arenas (`Vec` for D1 ids, `Vec<[u64;3]>` for D3 ids) keyed by +node dim. Arithmetic dispatches on `(dim(a),dim(b))` using the raw Goldilocks ops +(`GoldilocksField::add/mul`, the cubic-ext formulas). Reuse the per-thread buffer +pattern already in `evaluate_transitions` (`map_init`, `evaluator.rs:142`): the value +arena is a per-thread scratch `Vec` sized to `program.nodes.len()`. + +### 6.3 Fate of `TransitionConstraintAdapter` (Question 5) + +**End state:** `TransitionConstraint::evaluate` and the adapter's +`evaluate_prover`/`evaluate_verifier` are **deleted**. The user trait keeps +`degree/constraint_idx/period/offset/exemptions/end_exemptions` + the new `capture`. +`TransitionConstraintEvaluator` keeps the zerofier/degree/index methods + `capture`, +and **drops** `evaluate_prover`/`evaluate_verifier` (the per-row eval path no longer +goes through the trait object — it goes through the interpreter). The adapter shrinks +to a forwarder for `capture` and the metadata methods. + +**Transitional:** during migration we keep both `evaluate*` and `capture` so the old +per-row path and the new interpreter can run in parallel and be diff'd +(§9, §12). Only after every table validates bit-for-bit do we delete the old methods. + +--- + +## 7. GPU interpreter sketch (Question 7) + +Model on the `gpu_lde.rs` seam: TypeId checks gate entry, `repr(transparent)`/`[u64;3]` +layout lets us reinterpret `FieldElement` slices as raw `u64`, and a `_keep` device +handle holds the LDE columns resident from R1. + +- **Entry/dispatch.** A new `try_compute_transition_gpu(program, lde_trace, uniforms)` + guarded by `TypeId::of::()==Goldilocks && TypeId::of::()==Ext3` and an + lde-size threshold (mirror `check_base_layout`, `gpu_lde.rs:106`). Returns + `Option>>` of length `num_transition · lde_size` (the per-row + `Cᵢ` values), or `None` to fall back to the CPU interpreter. It is called from the + AIR's `compute_transition_prover` analog — but note the current + `evaluate_transitions` calls `compute_transition_prover` *per row*; for GPU we add a + batched override that produces all rows at once and feeds the accumulation loop + (this is a small refactor of `evaluate_transitions` to optionally accept a + precomputed `Cᵢ` matrix; the accumulation stays on whichever side is cheaper). + `✗ UNCERTAIN`: exact placement of the batched call (per-row vs whole-table) needs a + design pass — the cleanest is a new `air.compute_transitions_batched(lde) -> + Option` that `evaluate_transitions` tries before the per-row loop. +- **What crosses the boundary (once per table).** The program blob (`GpuOp[]` + + `consts_d1` + `consts_d3`), the uniforms (challenges, alpha_powers, periodic + columns, table_offset, packing shifts-as-consts). The LDE main/aux columns are + already on device (`lde_trace.gpu_main()`/`gpu_aux()`, `gpu_lde.rs:832,915`). No + per-row H2D. +- **Kernel.** One `interpret_transition_ext3` kernel, one thread per LDE row + (strided like `barycentric_*_strided`). Each thread walks `nodes` left-to-right + into a small per-thread register/local array indexed by NodeId (program is tiny — + hundreds of nodes — fits in local/shared memory), loading `Main/Aux` from the + resident LDE at `(row + offset·stride)`, doing D1/D3 ops with the existing device + primitives (`gl_add/gl_mul/gl_sub` and `ext3_add/ext3_mul/ext3_sub`, verified + present `device.rs:124-131`). Writes `Cᵢ` for each emit. Because the program is + uniform across rows, this is an embarrassingly parallel single-field kernel — the + whole point of the IR. New `.cu` file `transition_interp.cu` + `Backend` field + + `load_function` (mirror `device.rs:227-229`). +- **Fallback.** Any unsupported op/dim, sub-threshold size, or non-Goldilocks → CPU + interpreter (identical IR, identical result). Same `Option`-returning contract as + every `try_*` in `gpu_lde.rs`. + +--- + +## 8. Inputs plumbing (Question 6) + +The interpreter needs the per-proof/per-row uniforms that today live in +`TransitionEvaluationContext` (`traits.rs:72-93`). They become **leaf opcodes** read +from a uniform table the interpreter is handed alongside the program: + +| Current source (verified) | IR leaf | Const-vs-varies | +|---|---|---| +| `periodic_values[j]` (`evaluator.rs:88-90`, filled per row) | `Op::Periodic{j}` | varies per row (Dim1) | +| `rap_challenges[i]` (`ctx`, `traits.rs:80`) | `Op::Challenge{i}` | per proof (Dim3) | +| `logup_alpha_powers[k]` (precomputed `evaluator.rs:53`) | `Op::AlphaPow{k}` | per proof (Dim3) | +| `logup_table_offset` (`evaluator.rs:199`, `traits.rs:82`) | `Op::TableOffset` | per proof (Dim3) | +| `packing_shifts` (8/16/24, `lookup.rs:53`) | `Op::ConstD1` | program constant | + +The interpreter signature: +```rust +fn interpret_prover(prog: &TableProgram, ctx: &TransitionEvaluationContext, + base: &mut [FieldElement], ext: &mut [FieldElement]); +``` +pulls `frame`, `periodic_values`, `rap_challenges`, `logup_alpha_powers`, +`logup_table_offset` straight out of `ctx` (already plumbed through +`evaluate_transitions`, `evaluator.rs:92-99`). **No new plumbing into the +evaluator** — the context already carries everything; we only change what *consumes* +it. For GPU, these uniforms upload once per table (challenges/alpha/offset are +per-proof; periodic is `num_periodic · lde_size` Dim1, uploaded once). + +--- + +## 9. Coexistence & migration (Question 9) + +- **Table-by-table migration is fully supported.** The interpreter dispatch is on the + AIR. We add `capture` to all constraints up front (it can default to a `todo!()` + or, better, a generic auto-capture, see below), but flip an AIR to *use* the + interpreter independently. Concretely, `AirWithBuses` gets an `Option`: + when `Some`, `compute_transition_prover` interprets; when `None`, it falls back to + the existing `transition_constraints().iter()…evaluate_prover` loop (the current + `traits.rs:267-269` default). So a table is "migrated" by building its program in + `AirWithBuses::new`; unmigrated tables keep the old path verbatim. +- **Auto-capture bridge (optional but valuable):** because every algebraic body is + data-independent, we *could* provide a blanket `capture` that runs the existing + generic `evaluate` against a recording `TableView` whose elements are IR nodes — + i.e. a `TableView` where `IrField` is a field-like type whose + `add/mul` push IR nodes. **However** that is precisely the "shadow IsField" trick + the sibling approach owns, and making `IrField: IsField` is the trait-tower fight + we're avoiding. So for *this* approach we hand-write `capture` per struct and do + **not** rely on an auto-bridge. (Mentioned for completeness; explicitly rejected + here to keep the approaches distinct.) +- **Feature/TypeId gating:** GPU path behind the existing `cuda` feature + TypeId + guard (no new feature). CPU interpreter is unconditional. A `LAMBDA_VM_USE_IR` + env/feature can force the old path for A/B benchmarking during migration. + +--- + +## 10. Exhaustive file-by-file change list + +**New files:** +- `crypto/stark/src/ir.rs` — `Dim`, `NodeId`, `Op`, `Node`, `Expr`, `IrBuilder` + (full API §3.2), `TableProgram`, const/CSE dedup. `~400 LOC`. +- `crypto/stark/src/interpreter.rs` — `interpret_prover`, `interpret_verifier`, + `Val` arena, op dispatch, D1/D3 raw arithmetic helpers. `~300 LOC`. +- `crypto/math-cuda/src/transition_interp.rs` + `cuda/transition_interp.cu` — GPU + kernel + host wrapper `compute_transition_ext3`. `~400 LOC + kernel`. +- `crypto/stark/src/gpu_transition.rs` — `try_compute_transition_gpu` dispatch + (TypeId guard, blob upload, fallback). `~250 LOC`. (Or fold into `gpu_lde.rs`.) + +**Modified — framework:** +- `crypto/stark/src/constraints/transition.rs`: + - `TransitionConstraint`: add `fn capture(&self, &mut IrBuilder)`; delete + `evaluate` (end state). + - `TransitionConstraintEvaluator`: add object-safe `fn capture(&self, &mut + IrBuilder)`; delete `evaluate_prover`/`evaluate_verifier` (end state). + - `TransitionConstraintAdapter`: forward `capture`; drop `evaluate_*`. +- `crypto/stark/src/lookup.rs`: + - `LookupBatchedTermConstraint`: replace `evaluate_verifier` body with `capture` + (§5.2). `LookupAccumulatedConstraint`: replace with `capture` (§5.3). + - Add offset-parameterized capture helpers mirroring + `compute_fingerprint_from_step` (1689), `compute_multiplicity_from_step` (1679), + `BusValue::accumulate_fingerprint_from_step` (738), `Packing::accumulate_*` (272), + `Multiplicity::evaluate_with` (1252). + - `AirWithBuses`: add `program: Option`; build it in `new` + (`lookup.rs:848`) by `capture`-ing every constraint after assembly; override + `compute_transition_prover`/`compute_transition` to interpret. +- `crypto/stark/src/traits.rs`: optionally add + `fn compute_transitions_batched(&self, lde) -> Option>` default `None` + (GPU batched hook for `evaluate_transitions`). +- `crypto/stark/src/constraints/evaluator.rs`: (optional) try the batched GPU hook + before the per-row loop; otherwise **unchanged**. +- `crypto/stark/src/lib.rs` / `crypto/math-cuda/src/lib.rs`: module decls. + +**Modified — every constraint struct (`capture` body, delete `evaluate`):** +- `prover/src/constraints/templates.rs`: `IsBitConstraint`, `AddConstraint`, + `AddOperand::capture_lo/hi`, `AddLinearTerm::capture`, `eval_terms`→`capture`. +- `prover/src/constraints/cpu.rs`: `ProductZeroConstraint`, `Arg2ExclusiveConstraint`, + `MemFlagsBitConstraint`, `RegNotReadIsZeroConstraint`, `Arg2Constraint`, + `RvdEqResConstraint`, `BranchRvdConstraint`, `BranchCondConstraint`, + `NextPcAddConstraint`, `res_word`→capture helper. +- `prover/src/tables/`: `mul.rs (MulConstraint+compute helpers)`, `dvrm.rs + (DvrmConstraint)`, `shift.rs (ShiftConstraint)`, `cpu32.rs (Cpu32Constraint)`, + `memw.rs (MemwConstraint)`, `memw_aligned.rs (MemwAlignedConstraint)`, + `memw_register.rs (MemwRegisterMuSumIsBit)`, `load.rs (LoadConstraint)`, + `store.rs (StoreConstraint)`, `lt.rs (LtConstraint)`, `eq.rs (EqXorConstraint)`, + `branch.rs (BranchConstraint)`, `commit.rs (CommitConstraint)`, + `keccak.rs (one struct)`, `ec_scalar.rs (MulZeroConstraint)`, + `ecsm.rs (ConvCarry, ColIsZero, CarryBit, OverflowRequired)`, + `ecdas.rs (ConvCarry, ColIsZero, MulZero)`. + +**Key new type/function signatures (summary):** +```rust +pub struct TableProgram { nodes, consts_d1, consts_d3, emits, num_base, … } +pub struct IrBuilder { … } impl IrBuilder { main/aux/const_*/periodic/challenge/alpha_power/table_offset/add/sub/mul/neg/emit/finish } +pub fn interpret_prover(&TableProgram, &TransitionEvaluationContext, &mut[FE], &mut[FE]); +pub fn interpret_verifier(&TableProgram, &TransitionEvaluationContext, &mut[FE]); +trait TransitionConstraint { fn capture(&self, &mut IrBuilder); … } // generic evaluate removed +trait TransitionConstraintEvaluator { fn capture(&self, &mut IrBuilder); … } // evaluate_* removed +pub(crate) fn try_compute_transition_gpu(&TableProgram, &LDETraceTable, …) -> Option>>; +``` + +--- + +## 11. Risks & unknowns, ranked (brutally honest) + +1. **Breadth of the manual rewrite (33 structs / 19 files, ~600-800 LOC of bodies).** + This is the dominant cost and risk. Every body is mechanical but the multi-kind + mega-constraints (`dvrm` 11 kinds, `cpu32` 8, `shift` 7) have many capture paths + that are easy to transcribe subtly wrong. *Mitigation:* the bit-for-bit + parallel-path validation (§12) catches any divergence immediately; migrate one + table at a time behind the `Option` flag. +2. **LogUp `LookupAccumulatedConstraint` offset handling + lost per-row zero-skip.** + The fingerprint helpers must thread `offset` (next-row reads, §5.3) and the IR + cannot do the data-dependent `result==0` multiply-skip (§5.1). Correctness is + safe (value-preserving); the cost is a few extra D1×D3 muls/row. *Risk:* the + skip might matter more than expected on wide-bus tables; measure before deleting + the old path. `? INFERRED` it's negligible vs. dispatch savings — not yet + measured. +3. **Verifier-side typing (`Main` reads are Dim3 in the verifier).** The IR's + per-node `dim` is the prover's; the verifier interpreter must promote D1 leaves + to D3 and run everything as D3. If any constraint body relied on F-specific + behavior (e.g. `inv()` in base field) this would break — but I verified the + algebraic bodies only use `+ - * ` and `const` (the only "division" is + multiply-by-precomputed-`INV_SHIFT_32` const, `templates.rs:30`, which is just a + `Mul` by a constant — safe in any field). ✓ VERIFIED no body calls `inv()` at + eval time. +4. **GPU kernel program-size / divergence.** Programs are small (hundreds of nodes) + and uniform across rows (no divergence), but the per-thread value arena must fit + in registers/local mem; a large mega-constraint program (`dvrm`/`shift` are the + biggest) could spill. Keccak is NOT a concern (mostly reused `AddConstraint`s, + verified). *Mitigation:* per-thread arena lives in shared/local mem indexed by + NodeId; CPU fallback always available; GPU is opt-in per table above a threshold. +5. **Refactor of `evaluate_transitions` for the batched GPU hook.** The current loop + is per-row (`evaluator.rs:79`); a whole-table GPU call needs either a batched + path or accepting a precomputed `Cᵢ` matrix. `✗ UNCERTAIN` on the cleanest seam; + CPU interpreter needs none of this (it slots into the existing per-row call). +6. **CSE/const-dedup correctness.** Optional, but if the CSE key mis-merges two ops + with the same shape but different dim, results corrupt. *Mitigation:* key on + `(Op, Dim)`; or ship without CSE first (correctness independent of it). + +--- + +## 12. Effort estimate & validation strategy + +### Effort (by workstream) +- **IR + builder + CPU interpreter (framework):** `ir.rs` + `interpreter.rs` + + trait changes + `AirWithBuses` wiring. **~4-5 days.** Highest design value; + unblocks everything. +- **Rewrite algebraic constraints (33 structs, ~600-800 LOC):** 11 small structs in + `constraints/` (~1 day) + 22 in `tables/`, of which ~10 are multi-kind dispatch + structs. Budget the small ones at ~10-15/day; the multi-kind ones at ~0.5-1.5 each + (dvrm/cpu32/shift the costliest): **~4-6 days** total. +- **LogUp framework (2 constraints + offset-parameterized helpers):** **~2-3 days** + — small count but the highest per-line care (the crux). +- **Validation harness (parallel old/new diff):** **~1-2 days.** +- **GPU interpreter (kernel + dispatch + batched hook):** **~5-7 days** incl. the + `evaluate_transitions` batched seam and parity tests. Can land *after* the CPU + path is fully migrated and validated. +- **Total: ~2.5-3.5 weeks** for CPU-complete + validated; +1-1.5 weeks for GPU. + +### Validation (bit-for-bit, real tables, parallel paths) +1. **Keep the old generic `evaluate*` alongside `capture` during migration.** In a + `#[cfg(test)]` / debug harness, for each table and each LDE row, run BOTH: + the old `compute_transition_prover` (current trait-object loop) and + `interpret_prover(program, …)`, then `assert_eq!` the full `base_evals` and + `ext_evals` arrays. This is exactly the existing `validate_trace`-style + debug-assert pattern referenced in project memory; here it asserts + *evaluator equality* not trace validity. +2. **Drive it with the existing prove test** (`cargo test --release -p + lambda-vm-prover test_prove_elfs_test_sb_sh_8`) and the per-table bus tests + (`prover/src/tests/*_bus_tests.rs`, `*_tests.rs`) — these already exercise every + table's full constraint set on real traces. A mismatch pinpoints the exact + constraint_idx and row. +3. **Verifier parity:** at the OOD point, diff `air.compute_transition` (old) vs + `interpret_verifier` for the same frame — small (one frame), cheap, catches the + D1→D3 promotion bugs (Risk 3). +4. **GPU parity:** standard `gpu_lde.rs` pattern — compute on GPU and on CPU + interpreter, assert equal (the math-cuda test suite already does this per kernel; + add a `transition_interp` parity test). +5. Because the old path *coexists* (Option flag), CI can run both and assert equality + on every prove until we delete the old methods — zero-risk cutover. + +### What I could not confirm +- Struct count (33 algebraic + 2 LogUp / 19 files) and per-struct kind counts are + enumeration-verified; the ~600-800 LOC body total is an aggregate estimate (I read + `mul.rs`/`dvrm.rs` in full; the rest share the `kind`-enum→`compute()` pattern). +- Whether any table reads periodic columns (none of the bodies I read did; the + `Periodic` leaf is provided for completeness — `get_periodic_column_values` + defaults to empty, `traits.rs:290`). `? INFERRED` periodic is unused by current + tables. +- The cleanest `evaluate_transitions` seam for the batched GPU call (Risk 5). +- Keccak constraint body size/shape (didn't read it) — flagged for GPU register + pressure (Risk 4). diff --git a/thoughts/gpu-constraint-eval/plan-symbolic-field.md b/thoughts/gpu-constraint-eval/plan-symbolic-field.md new file mode 100644 index 000000000..1b4c4abad --- /dev/null +++ b/thoughts/gpu-constraint-eval/plan-symbolic-field.md @@ -0,0 +1,893 @@ +# Plan: GPU-ready constraint evaluation via a "Symbolic field" capture + +> **Status:** the CPU spike from this plan is **implemented** (PR #737, branch +> `spike/constraint-ir-symfield`). For the as-built state and the detailed, +> checkbox continuation plan, see **[`roadmap.md`](./roadmap.md)** — that is the +> execution / handoff doc. This file remains the full design rationale. + +**Approach:** keep the ~29 constraint bodies UNCHANGED; introduce a recording +field type `SymField`/`SymExt` whose field operations build an expression graph +instead of computing. Run each constraint's existing generic +`evaluate::(...)` (and the LogUp helpers) ONCE at setup to +capture a flat single-field Goldilocks IR, then INTERPRET that IR on CPU (prover +over the LDE coset + verifier at the OOD point) and on GPU (one universal +Goldilocks interpreter kernel). + +All file/line references below were read directly from the current tree. + +--- + +## 1. Overview & end-state + +After this change, each `AIR` (per table) owns, in addition to its existing +`Vec>`, a captured **constraint program**: +a flat list of typed Goldilocks IR ops plus a per-constraint root id. The program +is built once, at AIR construction, by running every constraint through a +recording field (`SymField`/`SymExt`) and recording the LogUp framework +constraints (`LookupBatchedTermConstraint`, `LookupAccumulatedConstraint`) via +the same recording field. At evaluation time, an **interpreter** walks the IR: +on CPU it replaces the per-row `air.compute_transition_prover(...)` call inside +`ConstraintEvaluator::evaluate_transitions` (crypto/stark/src/constraints/evaluator.rs:100) +and the verifier's `air.compute_transition(...)` call +(crypto/stark/src/verifier.rs:209); on GPU it is one Goldilocks kernel that +reads the serialized IR plus the device-resident LDE columns and produces the +per-constraint `Cᵢ` values. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary` and all +zerofier/coefficient machinery stay exactly where they are in +`evaluate_transitions` — the IR only replaces the step that produces each +constraint's scalar `Cᵢ`. + +``` + ┌─ capture (ONCE, at AIR::new, concrete types known) ─┐ +constraint structs ──► run evaluate::(sym_step) │ +LogUp framework ──► run evaluate_batched/accumulated::(...) │ + records into thread-local arena ──► ConstraintProgram │ + └────────────────────────────────────────────────────┘ + │ (serialize) + ┌─────────────────────────────────┼───────────────────────────────┐ + CPU prover (per LDE row) CPU verifier (1 OOD point) GPU kernel + interp(program, frame) ─► Cᵢ interp(program, ood_frame) ─► Cᵢ interp over device cols + │ │ │ + └─► Σ βᵢ·Cᵢ·Zᵢ⁻¹ (unchanged accumulation in evaluate_transitions / verifier) +``` + +The boxed `dyn TransitionConstraintEvaluator` path is retained verbatim as a +fallback and as the differential-test oracle (Section 9, 12). + +--- + +## 2. The IR (concrete Rust data structures) + +The IR is **single-field over Goldilocks**, with a dimension tag distinguishing +base (`dim1`, one u64) from extension (`dim3`, three u64). New crate module: +`crypto/stark/src/symbolic/ir.rs`. + +```rust +/// Field-arithmetic dimension of a node's value. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Dim { D1, D3 } // base Goldilocks, or its degree-3 extension + +/// A leaf input slot, resolved by the interpreter against the current frame +/// and the per-proof uniform inputs. +#[derive(Clone, Copy, Debug)] +pub enum Leaf { + /// Main trace column read: step.data[row][col], offset selects frame step. + Main { step: u8, row: u8, col: u16 }, // dim1 (base) for prover, dim3 for verifier + /// Aux trace column read: step.aux_data[row][col]. + Aux { step: u8, row: u8, col: u16 }, // always dim3 + /// Periodic column value at this row. + Periodic { idx: u16 }, // dim1 + /// rap_challenges[idx] (z, alpha, ...) + Rap { idx: u16 }, // dim3 + /// logup_alpha_powers[idx] + AlphaPow { idx: u16 }, // dim3 + /// logup_table_offset + TableOffset, // dim3 + /// One of the three precomputed packing shift constants (2^8, 2^16, 2^24) + Shift { which: u8 }, // dim1 (prover) / dim3 (verifier) +} + +/// One IR instruction. Indices are u32 ids into the program's `nodes` arena. +#[derive(Clone, Copy, Debug)] +pub enum Op { + Const1(u64), // dim1 literal (from FieldElement::from(u64/i64), one(), zero()) + Const3([u64; 3]), // dim3 literal (rare: produced by to_extension / from(u64) in E) + Leaf(Leaf), + Add(u32, u32), + Sub(u32, u32), + Mul(u32, u32), + Neg(u32), + // Embed a dim1 value into dim3 (the to_extension() / IsSubFieldOf::embed step, + // and the implicit base→ext promotion that F×E ops perform). + Embed(u32), +} + +/// A captured program for one table's transition constraints. +pub struct ConstraintProgram { + pub nodes: Vec, // topologically ordered (id i only references < i) + pub dims: Vec, // dims[i] = result dimension of nodes[i] + pub roots: Vec, // roots[c] = node id of constraint c's value Cᵢ + pub num_base: usize, // first num_base roots are dim1 (base-field) constraints + // metadata needed to size interpreter input arrays: + pub max_step: u8, pub max_main_col: u16, pub max_aux_col: u16, +} +``` + +**Typing rule.** Every node carries a `Dim`. `Add/Sub/Mul` of (D1,D1)→D1; +any operand D3 ⇒ result D3 (the interpreter auto-`Embed`s the D1 operand, +matching the `F: IsSubFieldOf` mixed-arithmetic the field tower performs at +crypto/math/src/field/element.rs:344). `Embed(D1)→D3`. This mirrors the real +arithmetic exactly: a base×ext multiply is 3 Goldilocks muls (the +`IsSubFieldOf::mul` at crypto/math/src/field/extensions_goldilocks.rs:413), an +ext×ext multiply is one `dot_product_3` schoolbook (extensions_goldilocks.rs:297). + +**Serialization for GPU.** `nodes` is encoded as a packed `Vec` opcode +stream: `[opcode_tag, operand_a, operand_b]` (3×u32 per node; `Const1`/`Const3` +store their literal in a side `Vec` indexed by operand_a). `dims` is a +`Vec`. `roots` is a `Vec`. This is a flat POD layout that copies to the +device as three buffers (`ops: &[u32]`, `consts: &[u64]`, `roots: &[u32]`), +following the same "reinterpret as `&[u64]`/`&[u32]`, transmute-free POD" +discipline used by the GPU LDE bridge in crypto/stark/src/gpu_lde.rs. + +--- + +## 3. Capture front-end — `SymField` design (the distinguishing section) + +`SymField` is a **marker type** that implements `IsField`, exactly like +`GoldilocksField` is a zero-sized marker whose `BaseType = u64` +(crypto/math/src/field/goldilocks.rs:70-73). The constraint bodies are generic +over the *field marker* `F` and operate on `FieldElement`, whose data is +`F::BaseType` (crypto/math/src/field/element.rs:50-52). So we choose: + +```rust +pub struct SymField; // base-field recorder (dim1) +pub struct SymExt; // extension recorder (dim3) +impl IsField for SymField { type BaseType = SymId; ... } +impl IsField for SymExt { type BaseType = SymId; ... } +``` + +where `SymId` wraps a `u32` node id (plus the `Dim` it denotes, see arena +decision). Because `BaseType` is just an id, every `IsField::add/mul/...` call +*records* a node into a thread-local arena and returns a fresh id. + +### Q1 — ARENA PROBLEM: thread-local arena returning u32 ids (chosen) + +`IsField` ops are static, contextless `fn mul(a: &BaseType, b: &BaseType) -> BaseType` +(crypto/math/src/field/traits.rs:104-112). There is no `&self`/arena parameter +to thread. Two options: + +* **`BaseType = Arc` (tree, hash-consed).** Each op allocates an `Arc` + node holding its children `Arc`s. Dedup requires hash-consing through a + thread-local `HashMap>`. *Downsides:* Arc clone/drop traffic + during capture, recursion in `Drop` for deep trees, and we *still* need a + thread-local for the hash-cons table — so it buys nothing over ids while + costing pointer-chasing and an `Arc` per node. Rejected. + +* **Thread-local arena returning `u32` ids (CHOSEN).** A `thread_local!` arena: + + ```rust + thread_local! { + static ARENA: RefCell> = const { RefCell::new(None) }; + } + struct Arena { + nodes: Vec, + dims: Vec, + cse: HashMap, // hash-consing: (opcode + operand ids) → id + } + ``` + + `BaseType` is a small `Copy` struct: + + ```rust + #[derive(Clone, Copy, Debug, Default)] + pub struct SymId { id: u32, dim: Dim } // Default = id 0 ... see Q2 Default note + ``` + + Each op does `ARENA.with(|a| { let a = a.borrow_mut().as_mut().unwrap(); + a.push(Op::Mul(x.id, y.id)) })` where `push` consults `cse` for dedup + (hash-consing gives a DAG, not a tree, for free). Capture is wrapped in + `with_arena(|| { ... run constraints ...; arena.take() })` which installs a + fresh `Arena`, runs the closure, and extracts `(nodes, dims, roots)`. + + This avoids `Arc` entirely, gives DAG dedup via the `cse` map, is `Copy` + (so `.clone()` in constraint bodies — used heavily, e.g. templates.rs:97, + cpu.rs:147 — is free and correct), and the only state lives in one + `thread_local`. Capture runs single-threaded per program (it's a setup-time + one-shot per table), so the thread-local is uncontended. **This is the + pick.** + + Hash-consing is mandatory, not optional: without it the ADD-carry templates + (templates.rs:414-440, `compute_carry_1` recomputes `compute_carry_0`) and + the LogUp fingerprints (each `compute_fingerprint_from_step` re-reads the same + columns) would blow up the node count. With `cse`, `compute_carry_0`'s subtree + is shared. + +### Q2 — TRAIT-METHOD COVERAGE (exhaustive) + +`SymField` must satisfy `IsField` and the `BaseType` bounds. `SymExt` must +satisfy `IsField`. The `IsSubFieldOf for SymField` impl is also needed +because constraint bodies are bounded `F: IsSubFieldOf` and `evaluate` +returns `FieldElement` (transition.rs:352-355). Below, every required method +with its symbolic implementation or a flag. + +**`IsField for SymField` (BaseType = SymId, dim D1):** + +| Method | Symbolic impl | +|---|---| +| `type BaseType = SymId` | id+dim, Copy | +| `add(a,b)` | record `Add(a,b)` → D1 | +| `sub(a,b)` | record `Sub(a,b)` → D1 | +| `mul(a,b)` | record `Mul(a,b)` → D1 | +| `neg(a)` | record `Neg(a)` → D1 | +| `double(a)` | default `add(a,a)` works; or record `Add(a,a)` | +| `square(a)` | default `mul(a,a)` works | +| `zero()` | record/return `Const1(0)` id (default `BaseType::default()`; see note) | +| `one()` | record/return `Const1(1)` id | +| `from_u64(x)` | record `Const1(GoldilocksField::from_u64(x))` id | +| `from_base_type(x)` | identity (return x) | +| `inv(a)` | **PROBLEM if ever called** — emit `Op::Inv` only if needed; NOT used by any algebraic constraint nor by the LogUp framework constraints (verified: no `.inv()`/`.div()`/`.pow()` in prover/src/constraints/; LogUp clears denominators so the constraint bodies never invert — fingerprints are *subtracted/multiplied*, not divided, in `evaluate_batched_term_constraint` lookup.rs:1759 and `evaluate_accumulated_constraint` lookup.rs:1887). **Make `inv` `unimplemented!("symbolic inv")`** — if capture ever hits it we want a loud failure, not silent wrong IR. | +| `div(a,b)` | same: `unimplemented!()` (not reached) | +| `eq(a,b)` | **SUBTLE** — returns `bool`, can't be symbolic. Used by `result != FieldElement::zero()` short-circuits in lookup.rs:675, lookup.rs:790. Must return a **conservative `false`** so the "skip zero term" optimization is *not* taken during capture (we always record the multiply). See Q5; this is correct because the skip is a runtime data optimization, and the captured IR must be data-independent. | +| `pow`, `sqrt`, `legendre_symbol` | not reached; default impls call `mul`/`square` and would work but should never run. | + +**`BaseType: Clone + Debug + ByteConversion + Default + Send + Sync` +(traits.rs:101):** + +| Bound | For `SymId` | +|---|---| +| `Clone + Copy` | derive (it's `{u32, Dim}`) | +| `Debug` | derive | +| `Default` | derive — **but** `Default` is used by `FieldElement::default()` → `value: F::zero()` (element.rs:488) and by `Frame::preallocate` (frame.rs:90-95). During capture we *don't* call `preallocate`; we build a symbolic frame by hand (Q4). A derived `SymId::default()` = `{id:0,dim:D1}` is fine as long as id 0 is a valid node — we reserve node id 0 = `Const1(0)` so a stray default is the zero element. **Resolved, no problem.** | +| `Send + Sync` | `SymId` is `Copy` POD ⇒ auto. The thread-local arena is not part of `SymId`, so no `Send` issue. | +| `ByteConversion` | **FLAG — must implement but never call.** `write_bytes_be/to_bytes_be/from_bytes_be/from_bytes_le` ⇒ `unimplemented!()`. ByteConversion is only exercised by transcript/serialization paths (goldilocks.rs:436), which capture never touches. Acceptable: it's a trait-bound satisfier, not a live method. | + +**`IsField for SymExt` (BaseType = SymId, dim D3):** identical table, but every +recorded node is tagged D3, and `from_u64(x)` records `Const3([from_u64(x),0,0])` +(matching `Degree3...::from_u64` extensions_goldilocks.rs:399). `one()`→ +`Const3([1,0,0])`, `zero()`→`Const3([0,0,0])`. `inv`/`div` `unimplemented!()`. + +**`IsSubFieldOf for SymField` (traits.rs:17-25):** this is the mixed +base×ext arithmetic surface the field-element operators dispatch through +(element.rs:223,295,346). Each must record the correct mixed node: + +| Method | Symbolic impl | +|---|---| +| `mul(a: &SymId/*D1*/, b: &SymId/*D3*/) -> SymId/*D3*/` | record `Mul(a,b)` tagged D3 (the interpreter sees a D1×D3 mul and does the 3-mul base×ext path) | +| `add(a,b) -> D3` | record `Add(a,b)` D3 | +| `sub(a,b) -> D3` | record `Sub(a,b)` D3 | +| `div(a,b)` | `unimplemented!()` (not reached) | +| `embed(a: SymId/*D1*/) -> SymId/*D3*/` | record `Embed(a)` → D3 | +| `to_subfield_vec(b)` | `unimplemented!()` (not reached; only serialization uses it) | + +Note the blanket `impl IsSubFieldOf for F` (traits.rs:27-60) automatically +gives us `IsSubFieldOf for SymField` and `IsSubFieldOf for +SymExt` (the prover-frame `evaluate` with FF=F and the verifier-frame with +FF=E both rely on these reflexive impls). Those route to `SymField::mul` etc., +so no extra code. + +**`IsFFTField for SymField`?** The `AIR` trait bounds `Field: IsFFTField` +(traits.rs:139) and `AirWithBuses` further bounds `Field: IsPrimeField` +(lookup.rs:805). **But capture does NOT instantiate any `AIR`.** +Capture only calls the *constraint object's* generic `evaluate::(step)` and the LogUp helper fns `::` — those are +bounded only `FF: IsSubFieldOf, EE: IsField` (transition.rs:352-355, +lookup.rs:1759, lookup.rs:1887). So `SymField` needs **only** `IsField + +IsSubFieldOf`, NOT `IsFFTField`/`IsPrimeField`. This is the single most +important feasibility fact: it sidesteps `IsFFTField::{TWO_ADICITY, root, +field_name}` and `IsPrimeField::{canonical, from_hex, field_bit_size}` entirely +(none are reachable from `evaluate`). *Verified:* `evaluate`'s only bound is +`FF: IsSubFieldOf` (transition.rs:354), the LogUp inner fns' +only bound is `A: IsSubFieldOf, B: IsField` (lookup.rs:1759, lookup.rs:1887, +lookup.rs:1679, lookup.rs:1689). Capture never builds the AIR with sym types. + +### Q3 — Constants & `to_extension` / `one()` / `zero()` + +* `FieldElement::::from(i64/u64)` → `From`/`From` (element.rs:136,149) + → `F::from_u64(value)`. For `F = SymField` this records `Const1(c)` with + `c = GoldilocksField::from_u64(value)` (we *fold the real Goldilocks reduction* + at capture time so the literal stored is canonical). `i64` negatives go through + `-Self::from(abs)` (element.rs:157) → records `Neg(Const1(abs))`; or we can + constant-fold to `Const1(p - abs)`. Either is correct; constant-folding negatives + keeps the IR smaller. Examples captured this way: `inv_2_32` (templates.rs:30-36, + a `from(INV_SHIFT_32)`), `SHIFT_16` (cpu.rs:69), `AddLinearTerm` coefficients + `1<<16`, `1<<8`, `1<<24` (templates.rs:266-326), bus `LinearTerm` coefficients + (lookup.rs:656,772). +* `FieldElement::one()`/`zero()` (element.rs:550,556) → `F::one()`/`F::zero()`. + For `SymField` → `Const1(1)`/`Const1(0)`; for `SymExt` → `Const3([1,0,0])`/ + `Const3([0,0,0])`. The literal `FieldElement::::one()` appears all over + (templates.rs:98, cpu.rs:146). +* `.to_extension::()` (element.rs:566) → `>::embed(value)`. + Used by the adapter's verifier path `...evaluate(...).to_extension()` + (transition.rs:431). For `F=SymField, L=SymExt` this records `Embed(child)`. + **However** — see Section 4: in the *prover* capture we run the adapter with + FF=F (base), and in the *verifier* capture we run FF=E (already D3); we will + capture the constraint's **base value** (the `evaluate` result, dim D1) and let + the interpreter/accumulator handle the embed, mirroring how the real prover + keeps base constraints in `base_evals: &mut [FieldElement]` + (evaluator.rs:106-110). So `to_extension` is mostly *not* in the captured graph + for base constraints; it only appears if a constraint body itself calls + `to_extension`, which none of the algebraic ones do (they return D1). + +### Q4 — SYMBOLIC FRAME + +Capture needs a `TableView` (and `Frame`) +whose column reads return `Leaf` nodes. `TableView` is +`{ data: Vec>>, aux_data: Vec>> }` +(table.rs:397-399) and reads go through `get_main_evaluation_element(row, col)` +(table.rs:410) / `get_aux_evaluation_element` (table.rs:414). So we build a +symbolic frame by filling each cell with a `FieldElement::from_raw(SymId)` whose +id is a recorded `Leaf::Main { step, row, col }` / `Leaf::Aux { ... }`: + +```rust +fn symbolic_frame(num_steps, rows_per_step, num_main, num_aux) -> Frame { + let steps = (0..num_steps).map(|step| { + let data = (0..rows_per_step).map(|r| + (0..num_main).map(|c| + FieldElement::::from_raw(record_leaf(Leaf::Main{step,row:r,col:c})) + ).collect()).collect(); + let aux_data = (0..rows_per_step).map(|r| + (0..num_aux).map(|c| + FieldElement::::from_raw(record_leaf(Leaf::Aux{step,row:r,col:c})) + ).collect()).collect(); + TableView::new(data, aux_data) + }).collect(); + Frame::new(steps) +} +``` + +`num_steps` = `offsets.len()` (= 2 for LogUp tables, `transition_offsets: +vec![0,1]` lookup.rs:909). `rows_per_step` = step_size/blowup (1 for these +tables). The two `TransitionEvaluationContext` variants needed for capture +(Q5/Q6): a `Prover { frame: &Frame, periodic_values: +&[FieldElement], rap_challenges: &[FieldElement], +logup_alpha_powers, logup_table_offset, packing_shifts: &PackingShifts }`. +Each uniform input (periodic, rap, alpha pow, table offset, shifts) is also a +recorded `Leaf` (`Periodic`, `Rap`, `AlphaPow`, `TableOffset`, `Shift`). The +shift constants `PackingShifts::::new()` (lookup.rs:54) call +`FieldElement::::from(SHIFT_8/16)` and `&shift_8 * &shift_16` — those +record `Const1` + `Mul` automatically; but to keep the IR clean we instead +construct `PackingShifts` with `Leaf::Shift{0/1/2}` ids so the interpreter +injects the real precomputed constants at eval time (they're loop-invariant and +the existing code precomputes them once, lookup.rs:64). Both are correct; the +`Leaf::Shift` version matches the existing precompute and keeps shifts uniform. + +--- + +## 4. Capturing the algebraic constraints (the ~29 structs, via the adapter) + +The ~29 algebraic constraints implement the user-facing `TransitionConstraint` +trait and are wrapped by `TransitionConstraintAdapter` (transition.rs:393). +Their bodies are generic `evaluate(&self, step: &TableView) -> +FieldElement` (transition.rs:352). **We do NOT touch any body.** Capture +calls each constraint's `evaluate::(sym_step)` directly and +reads the returned `SymId` (the root for that constraint). + +**Count, verified by grep: there are 33 (not ~29) algebraic +`impl TransitionConstraint` structs**, not +just the CPU ones the team-lead memo listed. Beyond templates.rs/cpu.rs they +span prover/src/tables/: branch.rs:519, commit.rs:837, cpu32.rs:645, +dvrm.rs:1219, ec_scalar.rs:291, ecdas.rs:{363,402,426}, ecsm.rs:{663,698,791,816}, +eq.rs:262, keccak.rs:503, load.rs:572, lt.rs:536, memw_aligned.rs:708, +memw_register.rs:388, memw.rs:921, mul.rs:847, shift.rs:914, store.rs:282 +(plus the 11 in templates.rs/cpu.rs). The "zero body edits" win therefore +applies to **all 33**, including the large ones (keccak, ecsm, dvrm, mul) — a +bigger payoff than the memo implied, but those large bodies also drive risk 5/6 +(node count / GPU scratch). + +These constraints all return a base-field (`FF=F`) value, so we capture them as +**dim-D1 roots** placed in `roots[0..num_base]`, matching the prover's base +split (evaluator.rs:50, evaluator.rs:106). **Safe-op audit (first-hand + grep, +load-bearing for feasibility):** every body uses only `clone`, `+`, `-`, `*`, +`neg` (via `-x`), `FieldElement::from`, `one()`, `get_main_evaluation_element` — +e.g. `IsBitConstraint` (templates.rs:92-107: `&cond * &x * (one - x)`), +`AddConstraint` (templates.rs:442-467 + the carry helpers, which multiply by the +constant `inv_2_32`), `ProductZeroConstraint` (cpu.rs:105-112), `Arg2Constraint` +(cpu.rs:277-303), `BranchRvdConstraint`/`NextPcAddConstraint` +(cpu.rs:394-446, cpu.rs:518-571). Crucially, a grep over the **entire** +`prover/src` (non-test) finds **zero** `.inv()`/`.pow()`/`.div()`/`.sqrt()`/ +`.legendre_symbol()` calls and **zero** field-value conditionals (`== FieldElement`, +`.is_zero()`, `if …value()…`) across all 17 table files — so no body, and no +helper any body transitively calls, performs division/inversion/exponentiation +or branches on a field value. (The per-struct degree + body summary from the +enumeration sub-agent is appended at the end.) + +**Framework glue changes** (minimal, additive): + +1. New trait method on `TransitionConstraint` with a default that **panics**, and + override it for the adapter is *not* the route — instead add a free function + `capture_user_constraint>(c: &T, step: &TableView) -> SymId` + that just calls `c.evaluate::(step)`. Because the adapter + stores `T` (transition.rs:393, `TransitionConstraintAdapter(pub T)`), but + the AIR only keeps `Box` (lookup.rs:813), + we cannot recover `T` from the boxed object. **Therefore capture must run at + the point where concrete constraint types still exist — i.e. inside each + table's constraint-list builder** (e.g. `create_all_cpu_constraints` + cpu.rs:619), *before* `.boxed()`. See Section 9. + +2. Add a capture entry point to the `TransitionConstraintEvaluator` trait: + `fn capture(&self, ctx: &SymCaptureCtx) -> SymId;` with a default that calls + `evaluate_verifier` against a symbolic context... **but** `evaluate_verifier` + needs `&mut [FieldElement]` slots, and for the adapter it calls + `self.0.evaluate(...).to_extension()` (transition.rs:431). Running *that* + under sym types records the constraint plus a trailing `Embed`, giving a D3 + root. That is acceptable for capture purposes (the embed is a no-op cost on + D1→D3 and the accumulator can treat the root as D3). **This is the cleaner, + object-safe route:** add `fn capture(&self, ctx, &mut [SymId])` to + `TransitionConstraintEvaluator`, default-implemented by calling a sym version + of `evaluate_verifier`. The adapter's `capture` runs + `self.0.evaluate::(frame.step).to_extension()` → D3 root, OR + (better, to keep base/ext split) runs `evaluate` and stores the D1 root for + `idx < num_base`. We implement the latter: `capture` mirrors `evaluate_prover` + (transition.rs:439) — D1 root into base slot for base constraints, D3 for the + LogUp ones. This keeps the captured program's `num_base` aligned with + `air.num_base_transition_constraints()` (lookup.rs:1025). + +The recommended concrete design: add **one** method +`fn capture(&self, ctx: &SymCaptureContext, base_roots: &mut Vec, +ext_roots: &mut Vec)` to `TransitionConstraintEvaluator` +(crypto/stark/src/constraints/transition.rs). Default impl: run the verifier- +style body symbolically and push a D3 root. Adapter override +(transition.rs:395): run `self.0.evaluate::` and push a D1 root +when `idx < base_roots.capacity-marker`, else D3. The two LogUp framework +structs override `capture` to run their `evaluate_*_constraint` inner fns under +sym types (Section 5). + +--- + +## 5. Capturing the LogUp / extension framework constraints (Q5 — the crux) + +The two LogUp constraints do **not** go through the adapter; they +`impl TransitionConstraintEvaluator` directly and branch on the +`TransitionEvaluationContext` enum (lookup.rs:1741, lookup.rs:1868). The decisive +question: **are their helpers generic enough to run under SymField/SymExt?** + +**Verdict: YES — they are fully capturable, no hand-emit needed.** Evidence: + +* `compute_multiplicity_from_step, B: IsField>` (lookup.rs:1679) + — generic; body is `multiplicity.evaluate_with(|col| step.get_main_evaluation_element(0,col).clone())` + → `Multiplicity::evaluate_with` (lookup.rs:1252) uses only `one()`, `+`, `-`, + `*`, `FieldElement::from(coeff)`. All recordable. ✓ +* `compute_fingerprint_from_step, B: IsField>` (lookup.rs:1689) + — generic; body builds `FieldElement::::from(bus_id)` then loops + `bv.accumulate_fingerprint_from_step(...)` (lookup.rs:738) which uses + `get_main_evaluation_element`, `Packing::accumulate_fingerprint_with` + (lookup.rs:272: only `+`,`*`, shift consts), and `z - &linear_combination`. + All recordable. ✓ +* `evaluate_batched_term_constraint, B: IsField>` + (lookup.rs:1759) — generic inner fn; computes `c * fp_a * fp_b - term_a - + term_b`. ✓ +* `evaluate_accumulated_constraint, B: IsField>` + (lookup.rs:1887) — generic; `delta * f - m*sign` etc. ✓ + +**The two sign/branch subtleties, and why they're still capturable:** + +1. **`is_sender` sign logic** (lookup.rs:1780-1790, lookup.rs:1927-1932, + lookup.rs:1954-1956): these are `if interaction.is_sender { term } else { + -term }` — branching on a **compile-time-known `bool` field of the + interaction struct**, NOT on a field *value*. During capture `is_sender` is a + concrete `bool`, so the branch is resolved at capture time and we record + either `term` or `Neg(term)`. ✓ No data dependence. + +2. **`result != FieldElement::::zero()` short-circuit** in + `accumulate_fingerprint_from_step` (lookup.rs:790) and the column-major + variant (lookup.rs:675): this branches on a *field value* via `PartialEq` → + `F::eq`. For `SymField` we make `eq` return **`false` always** (Q2), so the + capture path *always records the multiply* (`*acc += result * alpha_powers[..]`). + This is **correct and conservative**: the skip is a runtime optimization for + rows where the value happens to be zero; the IR must be valid for *all* rows, + so it must include the multiply unconditionally. The interpreter then always + does the multiply — slightly more work than the optimized CPU path on + all-zero rows, but bit-identical results. ✓ (If we wanted to preserve the + optimization we could detect "operand is a `Const1(0)` node" at capture time + and constant-fold, recovering the bus-id-padding skip statically. Recommended + as a cheap IR peephole.) + +**Building the capture context.** We construct +`TransitionEvaluationContext::Prover { frame: &Frame, +rap_challenges: &[FieldElement], logup_alpha_powers: +&[FieldElement], logup_table_offset: &FieldElement, +packing_shifts: &PackingShifts, periodic_values: +&[FieldElement] }` (the enum at traits.rs:77-84). Every slice element +is a `Leaf` node (`Rap{idx}`, `AlphaPow{idx}`, `TableOffset`, `Shift{}`). The +frame has 2 steps (acc uses `frame.get_evaluation_step(0)` and `(1)`, +lookup.rs:1972-1973). We call the constraint's `evaluate_verifier` (or the new +`capture`) with this Prover context; the matching `match` arm +(lookup.rs:1794, lookup.rs:1963) fires the generic inner fn under sym types and +returns a D3 root. **No fallback hand-emit is required** — this is the key win +over a hand-written LogUp IR. + +One caveat to call out: `evaluate_verifier` writes into `transition_evaluations: +&mut [FieldElement]` (lookup.rs:1827). Under sym types `E=SymExt`, so the +result is a `FieldElement` whose value is the root `SymId` — we read it +back from the slot. The slice must be pre-filled with a sentinel; we size it to +`num_transition_constraints` and read `slot[constraint_idx]` after the call. ✓ + +--- + +## 6. CPU interpreter + +New module `crypto/stark/src/symbolic/interp.rs`. Two entry points, one shared +core. + +**Core:** `fn eval_program(prog: &ConstraintProgram, inputs: &Inputs, out: &mut Outputs)` +walks `prog.nodes` in id order, computing each node into a value array. Because +nodes are topologically ordered (id i references < i) a single forward pass with +a `Vec` (len = nodes.len()) suffices; `Value` is an enum +`{ D1(FieldElement), D3(FieldElement) }` with auto-embed on mixed ops. +`inputs` resolves `Leaf`s: `Main`/`Aux` from the current frame step/row/col, +`Periodic/Rap/AlphaPow/TableOffset/Shift` from the per-proof uniform arrays +(Section 8). Final: `out.base[c] = values[roots[c]]` for `c` (verifier.rs:198, `into_frame`) so *all* +reads are D3; we run `eval_program` with an `Inputs` whose `Main` leaves resolve +to the OOD frame's D3 cells (interpreter reads them as D3 directly — the program +is the same, only the leaf-resolution dimension differs). Output is the +`Vec>` consumed by the zerofier fold (verifier.rs:218-225), +untouched. + +**Base/ext handling.** The interpreter must do D1×D3 the cheap way (3 muls, +matching `IsSubFieldOf::mul` extensions_goldilocks.rs:413) and D3×D3 via +`dot_product_3` (one `Degree3...::mul` extensions_goldilocks.rs:297). We reuse +the real `FieldElement` / `FieldElement` arithmetic +inside `Value`, so the interpreter's per-node cost equals the boxed path's — the +IR overhead is just the opcode dispatch (a `match` per node), which is cheap +relative to a Goldilocks mul. For the prover the program is run with +`F=GoldilocksField, E=Degree3...`; for the verifier with `F=E=Degree3...`. + +--- + +## 7. GPU interpreter sketch + +One universal Goldilocks kernel, modeled on the gpu_lde TypeId+transmute seam. + +**Host seam** (`crypto/stark/src/symbolic/gpu_interp.rs`): a +`try_eval_program_gpu(prog, lde_trace, uniforms, out) -> Option<()>` that, +exactly like check_base_layout (gpu_lde.rs:106) / the barycentric dispatchers +(gpu_lde.rs:811), gates on `TypeId::of::() == GoldilocksField` and +`TypeId::of::() == Degree3...` (gpu_lde.rs:826-831), a size threshold, and a +device-resident main/aux LDE handle (`lde_trace.gpu_main()`/`gpu_aux()`, +gpu_lde.rs:832,915). On mismatch → `None` → CPU interpreter fallback. The +program's three POD buffers (`ops: &[u32]`, `consts: &[u64]`, `roots: &[u32]`, +Section 2) plus the uniform arrays (rap challenges, alpha powers, table offset, +periodic columns, shift consts — all reinterpreted to `&[u64]` via the same +`#[repr(transparent)]` pattern as weights_to_u64 gpu_lde.rs:196) are H2D-copied +**once** (they're constant across all LDE rows). The columns are already on +device from the R1 LDE keep-handles (gpu_lde.rs:459, `GpuLdeBase`/`GpuLdeExt3`). + +**Device kernel** (new file under crypto/math-cuda/src/, e.g. +`symbolic_interp.cu` + a `math_cuda::symbolic` Rust wrapper): one thread per LDE +row. Each thread allocates a small per-node scratch in registers/shared/local +memory (`nodes.len()` Goldilocks values — programs are small, ~hundreds of nodes +per table) and runs the same forward pass as the CPU core, using the existing +math-cuda Goldilocks device primitives: base mul/add/sub (the same reduce128 +identities as goldilocks.rs:197), and ext3 mul as device `dot_product_3` +(mirroring goldilocks.rs:290). The kernel writes `out[row*num_constraints + c]` +for each root. **What crosses the host/device boundary:** program buffers + uniforms +(small, once); columns (already resident); output = `num_constraints × lde_size` +ext3 values (or, with the base/ext split, `num_base × lde_size` base + the rest +ext3) — D2H once. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹` can stay on host (cheap) or be +fused into the kernel later; for v1 keep it on host to minimize surface, matching +how `apply_ext3_scalar` post-processes on host (gpu_lde.rs:694). + +The single-field design means **one kernel** handles every table — the per-table +difference is entirely in the data buffers (`ops/consts/roots`), so there is no +per-table CUDA codegen. This is the whole point of the interpreter approach. + +--- + +## 8. Inputs plumbing (Q6) + +Periodic values, rap_challenges, logup_alpha_powers, logup_table_offset, and +packing_shifts vary **per proof** but are **constant across all rows** of one +table's evaluation. They are already computed once per `evaluate_transitions` +call: `logup_alpha_powers` (evaluator.rs:53), `packing_shifts` (evaluator.rs:64), +`rap_challenges` (passed in), `logup_table_offset` (evaluator.rs:47), +`lde_periodic_columns` (evaluator.rs:251 — note periodic is **per-row**, indexed +by `col[i]`, so it is a row-varying leaf resolved like a column). They become IR +**leaf inputs** with these resolutions in the interpreter's `Inputs`: + +| Leaf | CPU resolution | GPU resolution | +|---|---|---| +| `Main{step,row,col}` | `frame.get_evaluation_step(step).get_main_evaluation_element(row,col)` | device LDE main column, strided by step·lde_step_size (frame.fill_from_lde logic, frame.rs:117) | +| `Aux{...}` | `...get_aux_evaluation_element` | device LDE aux column | +| `Periodic{idx}` | `periodic_buf[idx]` (= `lde_periodic_columns[idx][i]`) | device periodic column | +| `Rap{idx}` | `rap_challenges[idx]` | uniform buffer slot | +| `AlphaPow{idx}` | `logup_alpha_powers[idx]` | uniform buffer slot | +| `TableOffset` | `logup_table_offset` | uniform buffer slot | +| `Shift{which}` | `packing_shifts.{shift_8,16,24}` | uniform buffer slot | + +At capture time, the leaf *indices* (which rap challenge, which alpha power) are +fixed by how the constraint reads them (`rap_challenges[0]` = z, lookup.rs:1769; +`alpha_powers[alpha_offset]` walked in packing, lookup.rs:294). So the program +encodes the exact indices; the interpreter just gathers from the per-proof arrays. +The arrays' lengths are known at eval time (`max_bus_elements` → +`compute_alpha_powers` count, evaluator.rs:55). No re-capture per proof. + +--- + +## 9. Coexistence & object-safety + +* **Where capture runs.** Because the AIR only stores `Box` (lookup.rs:813) and the adapter erases the + concrete `T` (transition.rs:393), the cleanest object-safe route is to add a + `capture` method to the **`TransitionConstraintEvaluator` trait** (which the + boxed objects *do* expose). The adapter's `capture` (transition.rs:395) calls + `self.0.evaluate::` — concrete `T` is in scope there. The two + LogUp structs override `capture` to run their generic inner fns. Then a single + pass over `air.transition_constraints()` (the existing `Vec>`, + traits.rs:314) captures the whole program. This means **the AIR builds its + `ConstraintProgram` once in a new default method** + `AIR::constraint_program(&self) -> ConstraintProgram` that iterates the boxed + constraints and calls `capture` on each within a `with_arena` scope. No table + builder needs editing. + +* **`capture` and object safety.** Adding `fn capture(&self, ctx: + &SymCaptureContext, base: &mut Vec, ext: &mut Vec)` to the + trait keeps it object-safe (no generics in the method signature; `SymField`/ + `SymExt` are concrete). The default impl runs the verifier-shaped body + symbolically. ✓ + +* **Generic boxed path retained as fallback.** `compute_transition_prover` + (traits.rs:254) and `compute_transition` (traits.rs:223) stay. A feature flag + `symbolic-interp` (or a runtime toggle) selects, inside + `evaluate_transitions` (evaluator.rs:100) and the verifier (verifier.rs:209), + whether to call the IR interpreter or the boxed path. Default off until the + differential test (Section 12) is green; then default on. + +* **TypeId gating for GPU.** The GPU interpreter only engages for the real + `GoldilocksField`/`Degree3...` instantiation (Section 7), identical to + gpu_lde.rs:119-152. For any other field the host code transparently uses the + CPU interpreter or the boxed path. + +* **Cache the program.** `ConstraintProgram` is built once per AIR and stored in + the AIR (or in `ConstraintEvaluator::new`, evaluator.rs:188, alongside + `boundary_constraints`). It is immutable and `Send + Sync` (POD), so it's + shared across all Rayon workers and reused across proofs of the same table + shape. + +--- + +## 10. Exhaustive file-by-file change list + +**New files:** + +* `crypto/stark/src/symbolic/mod.rs` — module root, re-exports. +* `crypto/stark/src/symbolic/sym_field.rs` — + `pub struct SymField; pub struct SymExt; #[derive(Clone,Copy,Default,Debug)] + pub struct SymId{id:u32,dim:Dim}`; `impl IsField for SymField/SymExt`; + `impl IsSubFieldOf for SymField`; `impl ByteConversion for SymId` + (unimplemented stubs); the `thread_local! ARENA` + `with_arena` + + `record(Op)->SymId` (hash-consing) + `record_leaf(Leaf)->SymId`. +* `crypto/stark/src/symbolic/ir.rs` — `Dim`, `Leaf`, `Op`, `ConstraintProgram`, + serialization (`to_pod()` → `(Vec, Vec, Vec)`). +* `crypto/stark/src/symbolic/capture.rs` — `SymCaptureContext` + (builds the symbolic `Frame`/`TableView`/uniform leaves, Q4), + `fn capture_program(constraints: &[Box>], + layout, num_base, ...) -> ConstraintProgram`. +* `crypto/stark/src/symbolic/interp.rs` — `Value`, `Inputs`, `Outputs`, + `fn eval_program(prog,&Inputs,&mut Outputs)` (CPU core + prover & verifier + adapters). +* `crypto/stark/src/symbolic/gpu_interp.rs` — `try_eval_program_gpu(...) + -> Option<()>` (TypeId gate + H2D uniforms + kernel launch + D2H), guarded by + the cuda feature. +* `crypto/math-cuda/src/symbolic_interp.rs` (+ `.cu`) — `math_cuda::symbolic:: + eval_program_*` device wrapper and the one universal Goldilocks/ext3 kernel. + +**Modified files:** + +* `crypto/stark/src/constraints/transition.rs` — add + `fn capture(&self, ctx: &SymCaptureContext, base: &mut Vec, ext: &mut + Vec)` to `TransitionConstraintEvaluator` (default impl runs verifier- + shaped body symbolically); override in `TransitionConstraintAdapter` + (transition.rs:395) to run `self.0.evaluate::`. +* `crypto/stark/src/lookup.rs` — override `capture` for + `LookupBatchedTermConstraint` (lookup.rs:1741) and + `LookupAccumulatedConstraint` (lookup.rs:1868) to run their generic inner fns + under sym types. The inner fns are **unchanged** (already generic). +* `crypto/stark/src/traits.rs` — add default method + `fn constraint_program(&self) -> ConstraintProgram` (iterates + `self.transition_constraints()` + `with_arena`). +* `crypto/stark/src/constraints/evaluator.rs` — in `evaluate` (evaluator.rs:216) + build/fetch the cached `ConstraintProgram`; in `evaluate_transitions` + (evaluator.rs:100) replace `air.compute_transition_prover(&ctx, base_buf, + transition_buf)` with `eval_program(...)` (behind the feature/toggle), with the + GPU dispatch tried first (gpu_interp `try_eval_program_gpu`, else CPU). +* `crypto/stark/src/verifier.rs` — at verifier.rs:209 replace + `air.compute_transition(&ctx)` with the verifier interpreter (same toggle). +* `crypto/stark/src/lib.rs` (or `mod.rs`) — `pub mod symbolic;`. +* `crypto/math/src/field/...` — **no change** (SymField lives in the stark + crate; it only needs the public `IsField`/`IsSubFieldOf` traits, which are + already public). If `ByteConversion` for `SymId` must be impl'd where the + trait is defined due to orphan rules, add a thin impl in math; otherwise keep + in stark (SymId is a stark type, ByteConversion is a math trait — the impl is + allowed in stark since SymId is local). ✓ orphan-rule-safe. + +**Key new signatures:** +```rust +impl IsField for SymField { type BaseType = SymId; fn mul(a:&SymId,b:&SymId)->SymId {record(Op::Mul(a.id,b.id),Dim::D1)} ... } +impl IsSubFieldOf for SymField { fn mul(a:&SymId,b:&SymId)->SymId {record(Op::Mul(a.id,b.id),Dim::D3)} fn embed(a:SymId)->SymId{record(Op::Embed(a.id),Dim::D3)} ... } +pub fn capture_program(cs: &[Box>], layout:(usize,usize), num_base:usize, offsets:&[usize], step_size:usize) -> ConstraintProgram; +pub fn eval_program(prog:&ConstraintProgram, inp:&Inputs<'_>, out:&mut Outputs<'_>); +pub(crate) fn try_eval_program_gpu(prog:&ConstraintProgram, lde:&LDETraceTable, uni:&Uniforms, out:&mut [FieldElement]) -> Option<()>; +``` + +--- + +## 11. Risks & unknowns, ranked + +1. **IsField-contract friction is LOW — feasibility CONFIRMED.** The decisive + finding: capture never instantiates `AIR`, only calls + `evaluate::` and the LogUp inner fns, whose bounds are only + `IsSubFieldOf + IsField` (transition.rs:354, lookup.rs:1759/1887/1679/1689). + So `SymField` needs **no** `IsFFTField`/`IsPrimeField` — the + `TWO_ADICITY`/root/canonical/from_hex methods are unreachable. The remaining + `IsField` methods that can't be symbolic (`inv`, `div`, `eq`) are either never + reached (`inv`/`div`: no division in any constraint body, verified by grep + + reading templates.rs/cpu.rs/lookup.rs) or handled by a conservative `eq=false` + (the only `eq` use is a runtime zero-skip optimization, lookup.rs:675/790, + which capture must *not* take). `ByteConversion`/`to_subfield_vec` are + bound-satisfier stubs that never run. **Residual risk:** a future constraint + body that calls `.inv()`/`.pow()`/branches on a field value would panic at + capture; mitigate with the loud `unimplemented!()` + a CI check. + +2. **LogUp capturability is HIGH-confidence YES.** The helpers are already + generic over `A: IsSubFieldOf, B` (lookup.rs:1679/1689) and the constraint + inner fns too (lookup.rs:1759/1887); `is_sender` is a compile-time bool, not a + field value (lookup.rs:1780); the only field-value branch is the zero-skip, + handled by `eq=false`. So **no hand-emit of LogUp IR is needed** — this is the + approach's biggest advantage over hand-writing. **Residual risk:** the `eq` + short-circuit means the captured IR always multiplies even by `Const1(0)` + bus-padding; mitigate with a constant-fold peephole (detect `Mul(_,Const(0))`/ + `Add(x,Const(0))` at capture) so the IR matches the optimized path's node + count and the GPU kernel doesn't waste lanes. Low effort, high value. + +3. **Bit-for-bit equivalence of the interpreter vs the boxed path.** The + interpreter reuses the real `FieldElement` arithmetic, so per-op results are + identical; the risk is in *order of operations* (field add/mul are + associative/commutative in value but the existing code's specific fold order + is what the OOD/LDE evaluations must match for the proof to verify). Since we + capture the *exact* call sequence the body executes (recording in evaluation + order), the IR's forward-pass order equals the body's order. **Residual + risk:** the zero-skip fold (lookup.rs:672) changes the *additive grouping* on + zero rows; with `eq=false` we always add, which is value-identical (adding 0). + So equivalence holds. Validate empirically (Section 12). + +4. **Capture-time arena correctness with hash-consing.** A wrong `cse` key + (e.g. not distinguishing D1 vs D3 nodes with the same operands) would alias + nodes of different dimension. Mitigate: include `Dim` in the `cse` key, or + never CSE across dims. Low risk, but must be tested. + +5. **GPU per-thread scratch pressure.** Each thread needs `nodes.len()` + Goldilocks values live. If a table's program is large (hundreds of nodes × + ext3 = hundreds × 24 bytes), register/shared pressure could limit occupancy. + Mitigate: liveness analysis to reuse scratch slots (a node's value is dead + after its last use), or spill to local memory. This is a perf risk, not a + correctness risk, and v1 can keep the accumulation on host. Medium. + +6. **Unknown: exact node counts per table.** Not yet measured, and there are + **33** algebraic constraints across many tables — the largest bodies + (keccak.rs, ecsm.rs, dvrm.rs, mul.rs, commit.rs) are big polynomial circuits + and will dominate node count. ADD/LogUp with hash-consing should be small (low + hundreds), but keccak/ecsm could be thousands of nodes, directly amplifying + risk 5 (GPU per-thread scratch). Resolve by instrumenting `capture_program` + to print `nodes.len()` per table during the differential test, and prioritize + liveness-based scratch reuse for the large tables. + +**Overall feasibility verdict: HIGH.** The SymField approach is sound; the +IsField-contract friction is manageable (the unreachable-method insight is the +crux) and LogUp captures cleanly with zero hand-emit. + +--- + +## 12. Effort estimate & validation strategy + +**Effort (engineer-days, by workstream):** + +* W1 — `SymField`/`SymExt`/`SymId` + arena + hash-consing + IsField/IsSubFieldOf + impls + stubs: **2–3 d**. (Mechanical once the unreachable-method set is fixed.) +* W2 — IR types + serialization + capture context (symbolic frame/uniforms) + + `capture` trait method + adapter/LogUp overrides + `AIR::constraint_program`: + **3–4 d**. (LogUp override is the fiddly part but the inner fns are unchanged.) +* W3 — CPU interpreter (core + prover slot in evaluator.rs + verifier slot) + + feature toggle: **3–4 d**. +* W4 — Differential test harness + peephole constant-fold + fix discrepancies: + **2–3 d**. +* W5 — GPU host seam + universal kernel + math-cuda wrapper + D2H wiring: + **6–10 d** (the largest and riskiest; v1 keeps accumulation on host). + +**Total: ~16–24 engineer-days**, with W1–W4 (~10–14 d) delivering a working, +validated CPU IR interpreter and W5 the GPU kernel. + +**Validation strategy (bit-for-bit, on a real table):** + +1. **Per-row prover diff.** In `evaluate_transitions` (evaluator.rs:79), for each + LDE row run BOTH `air.compute_transition_prover(&ctx, base_a, ext_a)` and + `eval_program(prog, ..., base_b, ext_b)`, and `assert_eq!` the base and ext + buffers element-by-element. Gate behind a `debug-checks`-style feature so it's + on in tests, off in production. Run against the existing test + `cargo test --release -p lambda-vm-prover test_prove_elfs_test_sb_sh_8` + (from project memory) for the CPU table (which exercises ADD/IS_BIT/LogUp). +2. **Per-constraint verifier diff.** At verifier.rs:209 compare + `air.compute_transition(&ctx)` vs the verifier interpreter at the single OOD + point; `assert_eq!` the full `Vec>`. Cheapest oracle (one + point). +3. **End-to-end.** With the interpreter as the live path, run the full + prove→verify test suite; a passing verify is the strongest equivalence check + (the composition poly and FRI depend on every `Cᵢ`). Run across all tables + (CPU, MEMW, LOAD, DECODE, MUL, BRANCH, REGISTER, PAGE, BITWISE, LT, HALT) so + every constraint shape and every `Packing`/`Multiplicity` variant is covered. +4. **GPU diff.** Compare `try_eval_program_gpu` output against the CPU + interpreter output (not the boxed path) element-wise, reusing the + `test-cuda-faults` style harness (gpu_lde.rs:1001) to also exercise the + CPU-fallback path. +5. **Node-count instrumentation.** Log `prog.nodes.len()` per table to size GPU + scratch and confirm hash-consing is effective (risk 4/5/6). + +--- + +## Appendix — full constraint enumeration (verified by reading every body) + +**33 algebraic `impl TransitionConstraint` +structs + 2 framework LogUp `TransitionConstraintEvaluator` structs.** Every one +uses ONLY capturable ops: field `+ - *` and negation, `FieldElement::from(u64/ +i64)`, `one()`/`zero()`, `.clone()`, `get_main_evaluation_element`/ +`get_aux_evaluation_element`, and `to_extension()`. **Zero** uses of `.inv()`, +`.pow()`, `/`, `.sqrt()`, field-value conditionals, or data-dependent loops. +Every conditional branches on **metadata** (`carry_idx`, `is_sender`, kind +enums), never on a field value. Helper fns (`carry_chain`, +`compute_multiplicity_from_step`, `compute_fingerprint_from_step`) contain only +statically-bounded loops. → For SymField this confirms the `IsField` impl needs +only add/sub/mul/neg/from(u64)/one/zero (+ `to_extension`/`embed`) to be +functional; `inv`/`pow`/`div`/real-`ByteConversion` can be `unimplemented!()` +stubs because no body invokes them. + +Algebraic structs (file:line): +- prover/src/constraints/cpu.rs: ProductZeroConstraint:96, Arg2ExclusiveConstraint:132, + MemFlagsBitConstraint:168, RegNotReadIsZeroConstraint:211, Arg2Constraint:266, + RvdEqResConstraint:331, BranchRvdConstraint:422, BranchCondConstraint:464, + NextPcAddConstraint:546 +- prover/src/constraints/templates.rs: IsBitConstraint:80, AddConstraint:470 + (AddOperand/AddLinearTerm with i64 coeffs → `from(i64)` Const nodes) +- prover/src/tables/: BranchConstraint(branch.rs:519), CommitConstraint(commit.rs:837), + Cpu32Constraint(cpu32.rs:645), DvrmConstraint(dvrm.rs:1219), EqXorConstraint(eq.rs:262), + MulZeroConstraint(ec_scalar.rs:291), ConvCarry(ecdas.rs:363), ColIsZero(ecdas.rs:402), + MulZero(ecdas.rs:426), ConvCarry(ecsm.rs:663), ColIsZero(ecsm.rs:698), + CarryBit(ecsm.rs:791), OverflowRequired(ecsm.rs:816), + KeccakAddressNoOverflowConstraint(keccak.rs:503), LoadConstraint(load.rs:572), + LtConstraint(lt.rs:536), MemwConstraint(memw.rs:921), MemwAlignedConstraint(memw_aligned.rs:708), + MemwRegisterMuSumIsBit(memw_register.rs:388), MulConstraint(mul.rs:847), + ShiftConstraint(shift.rs:914), StoreConstraint(store.rs:282) + +LogUp framework (lookup.rs): LookupBatchedTermConstraint:1741 +(`c·fp_a·fp_b − sign_a·m_a·fp_b − sign_b·m_b·fp_a`, degree 3), +LookupAccumulatedConstraint:1868 (running sum over acc col at row 0 AND row 1, +1–2 absorbed interactions, degree 2–3). + +**Multi-kind dispatch structs — IMPORTANT for the IR.** Several "structs" are +really one type that, via a kind-enum matched on **metadata at capture time**, +implements many distinct constraint kinds; each kind must capture to **its own +IR root** (the capture pass iterates the boxed objects, and each boxed object's +`constraint_idx()` already gives it a distinct root slot, so this falls out +naturally — but the plan's `roots` count is driven by `num_transition_constraints`, +not by the 33 struct count): ShiftConstraint(7 kinds), Cpu32Constraint(8), +LtConstraint(6), LoadConstraint(6), MulConstraint(6), DvrmConstraint(11), +BranchConstraint(5), MemwConstraint(3), MemwAlignedConstraint(3), +StoreConstraint(2). The total transition-constraint *count* (and thus root +count) is therefore well above 33; the IR's `roots` vector is sized by +`air.num_transition_constraints()` (traits.rs:286), which the capture pass +already respects by writing `roots[constraint_idx()]`. diff --git a/thoughts/gpu-constraint-eval/roadmap.md b/thoughts/gpu-constraint-eval/roadmap.md new file mode 100644 index 000000000..a59b6dc0c --- /dev/null +++ b/thoughts/gpu-constraint-eval/roadmap.md @@ -0,0 +1,164 @@ +# GPU constraint evaluation — implementation status & execution plan + +**Handoff doc.** Self-contained enough to continue without the originating discussion. +Describes the code as currently built, the decisions already made, and a detailed +checkbox plan to take it to a working, GPU-validated constraint evaluator. + +> **Chosen capture front-end: Plan B (explicit `IrBuilder` + per-constraint `capture()`).** +> Two spikes were built to compare: Plan A (symbolic field) = PR #737 / branch +> `spike/constraint-ir-symfield`; Plan B (builder) = PR #739 / branch +> `spike/constraint-ir-builder`. Both pass the same bit-for-bit diff test and reuse the +> same IR + interpreter. **Plan B is the production direction** (cleaner end-state — no fake +> `IsField`, no thread-local arena, explicit/auditable). Plan A remains as PR #737 for +> reference / comparison. + +--- + +## Goal & motivation + +Evaluate STARK **transition constraints on the GPU**, end-to-end, producing the +composition-polynomial evaluations **on-device**. The point is **data residency**, not +constraint-eval speed (constraints are not the prover bottleneck): once LDE/Merkle/FRI run +on the GPU, evaluating constraints on the CPU forces a D2H round-trip of the (large) LDE +trace, which dominates. Keeping eval on-device removes that transfer. + +## Architecture (decided) + +Capture each table's constraints **once** into a flat, single-field **Goldilocks IR** +(typed `Dim1`=base `u64` / `Dim3`=degree-3 extension `[u64;3]` op-DAG), then **interpret** +that IR on CPU and GPU. One universal kernel; the per-table difference is data. Modeled on +OpenVM's `cuda-backend` (cloned at `others/openvm-stark-backend`, the closest reference — +same FRI-STARK / LDE-quotient protocol; better-matched than SP1). + +### Decisions already made (don't relitigate without reason) +- **Capture front-end = Plan B (explicit builder).** Each constraint implements an + object-safe `Capture { fn capture(&self, &mut IrBuilder) }`, translating its `evaluate` + body into builder calls (`main`/`aux`/`add`/`sub`/`mul`/`neg`/`const_*`/`emit`). No fake + field, no arena, explicit and auditable. (Plan A — a recording "symbolic field" that + captures with zero body edits — was spiked first to validate the IR/interpreter cheaply; + kept as PR #737 + `plan-symbolic-field.md` for reference. We chose B for the cleaner + production end-state.) +- **Backend = interpreter, not codegen** for v1. Codegen stays available later from the same IR. +- **GPU value array = global memory, no register allocation** to start (simplest, works for + all program sizes). Add register allocation only if profiling needs it (Phase 6). +- **Keep the existing boxed CPU path** as the default + differential oracle behind a toggle + (the `capture()` methods are added alongside `evaluate`, which stays). +- **Device field arithmetic already exists** — reuse `crypto/math-cuda/kernels/ext3.cuh` + (`ext3::{add,sub,mul,mul_base}`, where `mul_base` = base×ext) and `kernels/goldilocks.cuh`. + Do **not** build new field math. + +--- + +## Phase 0 — CPU spikes ✅ DONE (two draft PRs; Plan B is the production base) + +Both spikes build, are fmt/clippy clean, and pass a bit-for-bit diff test (capture → +interpret == real `evaluate`, 1000 random rows) for `IsBit`/`Add`/`ProductZero`. They cover +**base-field algebraic constraints only**, single step (offset 0, row 0), main columns only +— no aux, no next-row, no LogUp, no uniforms, not wired into the prover, no GPU. + +**Shared (identical in both):** the IR and the CPU interpreter. +- `ir.rs` — `enum Dim { D1, D3 }`; `enum Op { Const1(u64), Const3([u64;3]), Var { main: bool, offset: u8, row: u8, col: u16 }, Add(u32,u32), Sub(u32,u32), Mul(u32,u32), Neg(u32), Embed(u32) }`; `struct ConstraintProgram { nodes: Vec, dims: Vec, roots: Vec }`. Typing: `(D1,D1)->D1`, any `D3` operand -> `D3` (auto-embed); `Embed: D1->D3`. +- `interp.rs` — `eval_program_base(prog, main_row) -> FieldElement`: forward pass over nodes into a `Value { D1 | D3 }` array, reusing real `FieldElement` arithmetic; resolves `Var{col}` from the row. + +**Plan B — the production base (PR #739, branch `spike/constraint-ir-builder`).** Module +`crypto/stark/src/constraint_ir/`: +- `ir.rs`, `interp.rs` — the shared IR + interpreter above (reused verbatim). +- `builder.rs` — `IrBuilder` (hash-conses nodes on `(Op, Dim)`, dedups base constants by value, dim-join `(D1,D1)->D1` else `D3`, reserves id 0 = `Const1(0)`) + `Expr { id, dim }`. Methods: `main(offset,col)`/`aux(offset,col)`, `const_base`/`const_signed`/`one`, `add`/`sub`/`mul`/`neg`, `emit(constraint_idx, e)`, `finish() -> ConstraintProgram`. +- `mod.rs` — object-safe `pub trait Capture { fn capture(&self, &mut IrBuilder); }`. +- Constraint impls (added **alongside** the unchanged `evaluate`, non-destructive): `IsBitConstraint`, `AddConstraint` (incl. `AddOperand`/`AddLinearTerm` lo/hi-limb mapping with i64 coeffs + the `inv_2_32` constant), `ProductZeroConstraint`. +- `prover/src/tests/constraint_ir_tests.rs` — the diff test. Node counts: product_zero **4**, is_bit_uncond **5**, is_bit_cond **7**, add_carry_0 **14**, add_carry_1 **21** (minimal — the builder only emits leaves for columns actually read). +- Run: `cargo test -p lambda-vm-prover constraint_ir_tests -- --nocapture` + +**Plan A — reference only (PR #737, branch `spike/constraint-ir-symfield`).** Module +`crypto/stark/src/symbolic/` (`sym_field.rs` recording field + capture). Retired the +"can a symbolic type satisfy `IsField`?" question (yes — needs only `IsField` + +`IsSubFieldOf`; capture never builds `AIR`). Not the production path. + +--- + +## Phase 1 — Full Plan-B capture coverage (all constraints, prover + verifier) + +Goal: implement `Capture` for **every** constraint of a real table (all ~33 algebraic + the +2 LogUp), for both prover and verifier, validated on CPU. The GPU runs this same IR, so +completeness/correctness must be nailed here first. + +- [ ] **Extend the IR** (`constraint_ir/ir.rs`): add leaf `Op` variants for the per-proof/per-row uniforms — `Periodic { idx }` (D1), `RapChallenge { idx }` (D3), `AlphaPow { idx }` (D3), `TableOffset` (D3), `Shift { which: u8 }` (D1). `Op::Var` already carries `offset`/`row`/`main` for next-row + aux reads. +- [ ] **Extend `IrBuilder`** (`constraint_ir/builder.rs`): add leaf constructors for the uniforms (`challenge`, `alpha_power`, `periodic`, `table_offset`, `shift`) and `const_ext([u64;3])`; ensure `aux(offset, col)` supports `offset=1` (next row). Make `emit` index `roots` by `constraint_idx` (the spike stores in emit order — switch to indexed for the full per-table program). +- [ ] **`Capture` for the remaining ~30 algebraic constraints** — mechanical translation of each `evaluate`/`compute` body to builder calls. Files: `prover/src/constraints/cpu.rs` (Arg2Exclusive, MemFlagsBit, RegNotReadIsZero, Arg2, RvdEqRes, BranchRvd, BranchCond, NextPcAdd), `prover/src/tables/{branch,commit,cpu32,dvrm,eq,ec_scalar,ecdas,ecsm,keccak,load,lt,memw,memw_aligned,memw_register,mul,shift,store}.rs`. The multi-kind ones (Dvrm 11 / Cpu32 8 / Shift 7 / Lt·Load·Mul 6) are the bulk — their `compute()` loops are statically bounded, so they unroll into builder calls at capture time. +- [ ] **`Capture` for the 2 LogUp constraints (the crux)** — `LookupBatchedTermConstraint` and `LookupAccumulatedConstraint` (`crypto/stark/src/lookup.rs`). Translate their bodies to builder calls: fingerprint = `challenge − Σ alpha_power·col` (mirror `compute_fingerprint_from_step`/`Packing::accumulate_fingerprint_with`), multiplicity (mirror `Multiplicity::evaluate_with`), `is_sender` as a compile-time `add` vs `neg`, the `c·fp_a·fp_b − …` / accumulated formulas. The accumulated one reads aux at offset 0 **and** 1 → use `aux(1, col)`. This is more work than Plan A's auto-capture (Plan A's inner fns were already field-generic) but is explicit and lives in one place. +- [ ] **Per-table program** (`crypto/stark/src/traits.rs`): `fn constraint_program(&self) -> ConstraintProgram` — iterate `self.transition_constraints()`, call `capture` on each into one `IrBuilder`, `roots[constraint_idx()]`, `num_base = num_base_transition_constraints()`. (Requires the object-safe `Capture` to be reachable from the boxed `TransitionConstraintEvaluator` — add `capture` to that trait, which is object-safe and matches the production design.) +- [ ] **Full interpreter** (`constraint_ir/interp.rs`): generalize to `eval_program(prog, inputs) -> (base: Vec>, ext: Vec>)` matching the `compute_transition_prover` contract — resolve all leaf kinds + offsets + aux; Dim1/Dim3 with auto-embed; add a verifier entry (all-D3 frame at the OOD point). +- [ ] **Acceptance test:** for the CPU table + ≥1 LogUp-heavy table, capture the full program, interpret per-row over a real LDE, and `assert_eq!` against `air.compute_transition_prover(...)` bit-for-bit; same for the verifier vs `air.compute_transition(...)` at the OOD point. + +--- + +## Phase 2 — Wire interpreter into prover/verifier (CPU), behind a toggle + +- [ ] Add a `constraint-ir` Cargo feature (or runtime env toggle) in `crypto/stark/Cargo.toml`. +- [ ] Cache the `ConstraintProgram` once in `ConstraintEvaluator::new` (`crypto/stark/src/constraints/evaluator.rs`). +- [ ] In `evaluate_transitions` (same file), behind the toggle, replace the `air.compute_transition_prover(&ctx, base_buf, transition_buf)` call (~line 100) with the IR interpreter; keep the boxed path as default + oracle. Leave the `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary` accumulation untouched. +- [ ] Verifier: same swap at `crypto/stark/src/verifier.rs` (`air.compute_transition`). +- [ ] **Acceptance:** full prove→verify suite passes with the toggle ON, across all tables — `cargo test --release -p lambda-vm-prover` (incl. `test_prove_elfs_*`). This is the **CPU end-to-end** checkpoint; the IR is now proven complete and correct independent of GPU. + +--- + +## Phase 3 — Device field primitives ✅ ALREADY EXIST + +Reuse `crypto/math-cuda/kernels/ext3.cuh` (`ext3::Fe3`, `ext3::{add,sub,mul,mul_base}`) and +`kernels/goldilocks.cuh`. Already used by the GPU FRI/inverse/barycentric/deep kernels. +Only remaining: confirm a `neg` (else `ext3::sub(zero, x)`) and include the header — do it +as part of Phase 4. + +--- + +## Phase 4 — GPU interpreter kernel + +Start **stripped** (mirror OpenVM's `GLOBAL=true` kernel: global-memory value array, no +register allocation, no bit-packed codec). Reference: `others/openvm-stark-backend/crates/cuda-backend/cuda/src/quotient.cu` (`cukernel_quotient`) and `cuda/include/codec.cuh`. + +- [ ] **Device IR layout** (`crypto/stark/src/constraint_ir/` + `crypto/math-cuda`): serialize `ConstraintProgram` to a `#[repr(C)]` flat node array (`{ op_tag: u32, a: u32, b: u32, dim: u32 }`) + a constants table + `roots` + `num_base`. Plus per-proof uniform device buffers (rap challenges, alpha powers, table_offset, periodic columns, shift consts). +- [ ] **Kernel** (`crypto/math-cuda/kernels/constraint_interp.cu` + Rust wrapper in `crypto/math-cuda/src/`): one thread per LDE row (tiled). Forward pass over the node array into a **per-thread value array in global memory** (one slot per node, strided per thread). Resolve `Var{main/aux, offset, col}` from the device-resident LDE columns (`GpuLdeBase`/`GpuLdeExt3` keep-handles from `trace.gpu_main()`/`gpu_aux()`). Dim1 ops via `goldilocks.cuh`, Dim3 via `ext3.cuh` (`mul_base` for D1×D3). **Fused accumulation:** Horner `acc = acc*alpha + Cᵢ` over the constraint roots, then `acc *= inv_zeroifier[row]` → write the composition-poly evaluation. Output stays on device. +- [ ] **Host dispatch** (`crypto/stark/src/constraint_ir/gpu_interp.rs`): `try_eval_program_gpu(...) -> Option<...>` gated on `TypeId::of::() == GoldilocksField && TypeId::of::() == Degree3GoldilocksExtensionField` + a size threshold (mirror `crypto/stark/src/gpu_lde.rs:119-152`). Upload program + uniforms once; launch; leave output device-resident. Fall back to the CPU interpreter / boxed path otherwise. +- [ ] **Pipeline integration:** add a whole-table GPU entry (e.g. `AIR::compute_transitions_batched(lde) -> Option` tried by `evaluate_transitions` before the per-row loop) so the composition-poly evals are produced on-device and feed the existing GPU Merkle commit with **no D2H of the `Cᵢ` matrix**. Reconcile zerofier/boundary accounting with the CPU semantics. +- [ ] **Acceptance:** compiles under `cargo build -p lambda-vm-prover --features cuda`. + +--- + +## Phase 5 — "Working on GPU" (the deliverable) — runs on the CUDA machine + +- [ ] **GPU↔CPU parity test** (extend `prover/tests/cuda_path_integration.rs` / `cuda_fallback_tests.rs`): composition-poly evals on GPU == CPU interpreter == boxed path, per table, on real traces. +- [ ] **End-to-end GPU prove→verify** on a real ELF with `--features cuda`. A passing verify is the goal. +- [ ] **Benchmark** (bench server): prove time with GPU constraints vs CPU constraints — confirm the data-residency win (no LDE D2H for constraint eval). + +--- + +## Phase 6 — Optimizations (only if a profile demands) + +- [ ] **Register allocation** — port OpenVM's transpiler liveness + linear-scan (`others/openvm-stark-backend/crates/cuda-backend/src/transpiler/mod.rs`) to shrink the per-thread value array (local `FpExt[N]` for small programs, smaller global buffer for large ones like Dvrm/Shift/ecsm). +- [ ] **DCE / const-fold peephole** — fold `×Const(0)`/`+Const(0)`; drop dead nodes. +- [ ] **Bit-packed codec** — only if H2D bandwidth shows up (unlikely; the rule stream is tiny and uploaded once). +- [ ] **Selective codegen** — given few-but-large tables, codegen the 1–3 hottest tables (nvcc does register allocation, no per-op dispatch) if interpreter overhead is material. Hybrid: interpreter baseline + codegen the hot ones. + +--- + +## Gotchas / invariants + +- **Single field:** Goldilocks base + degree-3 extension only. The IR's `Dim1`/`Dim3` and + the `ext3.cuh` primitives cover everything. +- **Object safety:** generic methods can't live on `Box`. + Plan B's `Capture` trait is **non-generic** (concrete `IrBuilder`), so it's object-safe; + capture runs once at setup, the per-row hot path only interprets the (data) IR. +- **Don't D2H the `Cᵢ` matrix:** fuse the accumulation in the GPU kernel so only the + (small) composition-poly evaluation crosses on-device into Merkle. +- **LDE columns are already device-resident** (`GpuLdeBase`/`GpuLdeExt3`); read them in place. +- *(Plan A only, not B:)* the symbolic-field path needed `eq → false` to defeat the runtime + zero-skip during capture. Plan B has no such hack — it emits exactly what `capture` writes. + +## Reference material (in-repo) + +- `others/openvm-stark-backend/crates/cuda-backend/` — `src/transpiler/{mod.rs,codec.rs}`, `src/quotient/`, `cuda/src/quotient.cu`, `cuda/include/codec.cuh`. The closest working reference (BabyBear; for Goldilocks the only deltas are 64-bit constants needing a side table, degree-3 ext, and they run all-FpExt with no base/ext split). +- `crypto/stark/src/gpu_lde.rs` — the TypeId+transmute generic→concrete-Goldilocks GPU seam to mirror. +- `thoughts/gpu-constraint-eval/plan-builder-rewrite.md` — the full Plan B design (the chosen approach; Phases 1+ detail its remaining sections). +- `thoughts/gpu-constraint-eval/plan-symbolic-field.md` — Plan A (the reference/comparison spike, PR #737). +- `thoughts/gpu-constraint-eval/README.md` — motivation + the SP1/OpenVM/zisk survey. +- PRs: **#739** (Plan B, production base) · **#737** (Plan A, reference). From fa3ffa920f2c9ae4a07b76b6a5cf98a72ec5b2f2 Mon Sep 17 00:00:00 2001 From: MauroFab Date: Tue, 30 Jun 2026 12:49:23 -0300 Subject: [PATCH 3/6] docs(gpu-constraint-eval): add airbender reference (validates interpreter-for-quotient; device-layout tricks) --- thoughts/gpu-constraint-eval/roadmap.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thoughts/gpu-constraint-eval/roadmap.md b/thoughts/gpu-constraint-eval/roadmap.md index a59b6dc0c..5f54491bd 100644 --- a/thoughts/gpu-constraint-eval/roadmap.md +++ b/thoughts/gpu-constraint-eval/roadmap.md @@ -116,7 +116,7 @@ as part of Phase 4. Start **stripped** (mirror OpenVM's `GLOBAL=true` kernel: global-memory value array, no register allocation, no bit-packed codec). Reference: `others/openvm-stark-backend/crates/cuda-backend/cuda/src/quotient.cu` (`cukernel_quotient`) and `cuda/include/codec.cuh`. -- [ ] **Device IR layout** (`crypto/stark/src/constraint_ir/` + `crypto/math-cuda`): serialize `ConstraintProgram` to a `#[repr(C)]` flat node array (`{ op_tag: u32, a: u32, b: u32, dim: u32 }`) + a constants table + `roots` + `num_base`. Plus per-proof uniform device buffers (rap challenges, alpha powers, table_offset, periodic columns, shift consts). +- [ ] **Device IR layout** (`crypto/stark/src/constraint_ir/` + `crypto/math-cuda`): serialize `ConstraintProgram` to a `#[repr(C)]` flat node array (`{ op_tag: u32, a: u32, b: u32, dim: u32 }`) + a constants table + `roots` + `num_base`. Plus per-proof uniform device buffers (rap challenges, alpha powers, table_offset, periodic columns, shift consts). If the program is small, pack the metadata into constant memory / a `__grid_constant__` proxy as airbender does (`<8 KB`, 1-byte coeff tags) — see Reference. - [ ] **Kernel** (`crypto/math-cuda/kernels/constraint_interp.cu` + Rust wrapper in `crypto/math-cuda/src/`): one thread per LDE row (tiled). Forward pass over the node array into a **per-thread value array in global memory** (one slot per node, strided per thread). Resolve `Var{main/aux, offset, col}` from the device-resident LDE columns (`GpuLdeBase`/`GpuLdeExt3` keep-handles from `trace.gpu_main()`/`gpu_aux()`). Dim1 ops via `goldilocks.cuh`, Dim3 via `ext3.cuh` (`mul_base` for D1×D3). **Fused accumulation:** Horner `acc = acc*alpha + Cᵢ` over the constraint roots, then `acc *= inv_zeroifier[row]` → write the composition-poly evaluation. Output stays on device. - [ ] **Host dispatch** (`crypto/stark/src/constraint_ir/gpu_interp.rs`): `try_eval_program_gpu(...) -> Option<...>` gated on `TypeId::of::() == GoldilocksField && TypeId::of::() == Degree3GoldilocksExtensionField` + a size threshold (mirror `crypto/stark/src/gpu_lde.rs:119-152`). Upload program + uniforms once; launch; leave output device-resident. Fall back to the CPU interpreter / boxed path otherwise. - [ ] **Pipeline integration:** add a whole-table GPU entry (e.g. `AIR::compute_transitions_batched(lde) -> Option` tried by `evaluate_transitions` before the per-row loop) so the composition-poly evals are produced on-device and feed the existing GPU Merkle commit with **no D2H of the `Cᵢ` matrix**. Reconcile zerofier/boundary accounting with the CPU semantics. @@ -157,6 +157,7 @@ register allocation, no bit-packed codec). Reference: `others/openvm-stark-backe ## Reference material (in-repo) - `others/openvm-stark-backend/crates/cuda-backend/` — `src/transpiler/{mod.rs,codec.rs}`, `src/quotient/`, `cuda/src/quotient.cu`, `cuda/include/codec.cuh`. The closest working reference (BabyBear; for Goldilocks the only deltas are 64-bit constants needing a side table, degree-3 ext, and they run all-FpExt with no base/ext split). +- `others/airbender/` (Matter Labs zksync-airbender, reference clone — **don't commit**) — a GPU-heavy prover in our **exact protocol family** (LDE-coset + FRI + quotient-over-divisors), field Mersenne31 + quartic (not Goldilocks). Independently chose a **generic interpreter kernel** for quotient/constraint eval (`gpu_prover/native/stage3.cu` `ab_generic_constraints_kernel` over `FlattenedGenericConstraintsMetadata`; Rust bridge in `gpu_prover/src/prover/stage_3_kernels.rs`), and reserved **codegen only for the irregular witness path** (`gpu_witness_eval_generator/`, a macro-DSL CPP-expanded to a straight-line kernel) — so it validates interpreter-for-quotient. **Steal:** the flat encoded constraint IR packed **<8 KB into constant memory** (`__grid_constant__`) with 1-byte ±1/explicit **coeff tags** + per-constraint term counts; the **column-major `ptr+stride`, one-thread-per-row** device layout with the whole working set bundled in one by-value proxy. **Caveats:** M31/quartic field + its coset/decompression math don't port to Goldilocks; their encoding assumes **degree ≤ 2** (ours is a general op-DAG up to degree 3 — so OpenVM's three-address IR is the closer template for the node walk; borrow airbender's *packing/layout* tricks, not its term encoding). - `crypto/stark/src/gpu_lde.rs` — the TypeId+transmute generic→concrete-Goldilocks GPU seam to mirror. - `thoughts/gpu-constraint-eval/plan-builder-rewrite.md` — the full Plan B design (the chosen approach; Phases 1+ detail its remaining sections). - `thoughts/gpu-constraint-eval/plan-symbolic-field.md` — Plan A (the reference/comparison spike, PR #737). From ea6f6b4d2af3dfe1274f3fbe4e59f0d0c10ee66d Mon Sep 17 00:00:00 2001 From: MauroFab Date: Tue, 30 Jun 2026 16:34:48 -0300 Subject: [PATCH 4/6] =?UTF-8?q?spike(stark):=20Phase=201=20=E2=80=94=20ful?= =?UTF-8?q?l=20constraint-ir=20capture=20coverage=20(incl.=20LogUp)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the constraint-ir spike (PR #739) from the minimal Phase-0 algebraic subset to full coverage: all ~33 algebraic transition constraints across prover/src/constraints/ and prover/src/tables/, plus the 2 LogUp framework constraints (LookupBatchedTermConstraint, LookupAccumulatedConstraint) in crypto/stark/src/lookup.rs. IR/builder/interpreter additions: - ir.rs: Periodic/RapChallenge/AlphaPow/TableOffset leaf Ops; ConstraintProgram gains num_base for the prover's base/ext split. - builder.rs: challenge/alpha_power/periodic/table_offset/const_ext leaf constructors; emit() now indexes roots by constraint_idx (was emit-order). - interp.rs: eval_program/eval_program_verifier, the full prover/verifier entry points matching compute_transition_prover/compute_transition; eval_program_base kept for the Phase-0 diff test. - traits.rs: AIR::constraint_program(), capturing every transition_constraints() entry into one program. - constraints/transition.rs: object-safe capture() on TransitionConstraintEvaluator (default unimplemented!, since most non-production examples don't need it); TransitionConstraintAdapter forwards to the wrapped Capture impl; boxed() requires Self: Capture so every production constraint must implement it. Capture impls added for every remaining algebraic constraint struct (cpu.rs, branch/commit/eq/load/store/memw*/keccak/ec_scalar/ecsm/ecdas, and the multi-kind mega-constraints mul/dvrm/shift/cpu32/lt), mirroring each evaluate()/compute() body 1:1. LogUp capture mirrors compute_fingerprint_from_step /Packing::accumulate_fingerprint_with/Multiplicity::evaluate_with, with is_sender resolved at capture time (add vs neg) and the accumulated constraint's next-row reads via aux(1, col). Gate: prover/src/tests/constraint_ir_tests.rs now captures the full CPU and EQ table programs and asserts bit-for-bit equality against compute_transition_prover (every LDE row) and compute_transition (verifier, OOD point) — covering both the batched-pair and both 1-/2-absorbed LogUp branches. Full prover test suite (430 tests, incl. test_prove_elfs_*) passes unchanged, confirming the additions are non-destructive. --- crypto/stark/src/constraint_ir/builder.rs | 52 ++- crypto/stark/src/constraint_ir/interp.rs | 178 +++++++++- crypto/stark/src/constraint_ir/ir.rs | 16 +- crypto/stark/src/constraint_ir/mod.rs | 2 +- crypto/stark/src/constraints/transition.rs | 36 +- crypto/stark/src/lookup.rs | 394 +++++++++++++++++++++ crypto/stark/src/traits.rs | 20 ++ prover/src/constraints/cpu.rs | 212 ++++++++++- prover/src/tables/branch.rs | 101 ++++++ prover/src/tables/commit.rs | 21 ++ prover/src/tables/cpu32.rs | 96 +++++ prover/src/tables/dvrm.rs | 201 +++++++++++ prover/src/tables/ec_scalar.rs | 16 + prover/src/tables/ecdas.rs | 183 ++++++++++ prover/src/tables/ecsm.rs | 159 +++++++++ prover/src/tables/eq.rs | 19 + prover/src/tables/keccak.rs | 64 ++++ prover/src/tables/load.rs | 70 ++++ prover/src/tables/lt.rs | 119 +++++++ prover/src/tables/memw.rs | 44 +++ prover/src/tables/memw_aligned.rs | 37 ++ prover/src/tables/memw_register.rs | 14 + prover/src/tables/mul.rs | 97 +++++ prover/src/tables/shift.rs | 166 +++++++++ prover/src/tables/store.rs | 26 ++ prover/src/test_utils.rs | 2 +- prover/src/tests/constraint_ir_tests.rs | 257 +++++++++++++- 27 files changed, 2573 insertions(+), 29 deletions(-) diff --git a/crypto/stark/src/constraint_ir/builder.rs b/crypto/stark/src/constraint_ir/builder.rs index 29328d9b2..9ed7da4e6 100644 --- a/crypto/stark/src/constraint_ir/builder.rs +++ b/crypto/stark/src/constraint_ir/builder.rs @@ -113,6 +113,26 @@ impl IrBuilder { ) } + /// A periodic column read at the current row (`D1`). + pub fn periodic(&mut self, idx: usize) -> Expr { + self.push(Op::Periodic { idx: idx as u16 }, Dim::D1) + } + + /// A LogUp RAP challenge, uniform per proof (`D3`). + pub fn challenge(&mut self, idx: usize) -> Expr { + self.push(Op::RapChallenge { idx: idx as u16 }, Dim::D3) + } + + /// A precomputed LogUp alpha power, uniform per proof (`D3`). + pub fn alpha_power(&mut self, idx: usize) -> Expr { + self.push(Op::AlphaPow { idx: idx as u16 }, Dim::D3) + } + + /// The LogUp table offset `L/N`, uniform per proof (`D3`). + pub fn table_offset(&mut self) -> Expr { + self.push(Op::TableOffset, Dim::D3) + } + // --------------------------------------------------------------------- // Constants // --------------------------------------------------------------------- @@ -139,6 +159,16 @@ impl IrBuilder { e } + /// An extension-field constant `[c0, c1, c2]`, each component reduced. + /// Dedup is via the general `(Op, Dim)` hash-cons (`push`); no separate + /// cache is needed since `Const3` is `Eq + Hash`. + pub fn const_ext(&mut self, v: [u64; 3]) -> Expr { + let c0 = *FieldElement::::from(v[0]).value(); + let c1 = *FieldElement::::from(v[1]).value(); + let c2 = *FieldElement::::from(v[2]).value(); + self.push(Op::Const3([c0, c1, c2]), Dim::D3) + } + /// The base-field constant `1`. pub fn one(&mut self) -> Expr { self.const_base(1) @@ -185,19 +215,29 @@ impl IrBuilder { /// Record `e` as the root for constraint `constraint_idx`. /// - /// Roots are stored in emit order; the minimal spike emits exactly one root - /// per program, so `constraint_idx` is accepted for parity with the - /// production design but not used to index `roots` here. - pub fn emit(&mut self, _constraint_idx: usize, e: Expr) { - self.roots.push(e.id); + /// `roots` is indexed by `constraint_idx` (grown/filled with sentinel `0` + /// as needed), so constraints can be captured in any order and a full + /// per-table program (one `emit` per `TransitionConstraintEvaluator` in + /// `transition_constraints()`) ends up with `roots[c]` = constraint `c`'s + /// value, matching `AIR::num_transition_constraints()` indexing. + pub fn emit(&mut self, constraint_idx: usize, e: Expr) { + if self.roots.len() <= constraint_idx { + self.roots.resize(constraint_idx + 1, 0); + } + self.roots[constraint_idx] = e.id; } /// Consume the builder and produce the captured program. - pub fn finish(self) -> ConstraintProgram { + /// + /// `num_base` is the number of leading (by `constraint_idx`) constraints + /// that are base-field (`D1`) rooted, matching + /// `AIR::num_base_transition_constraints()`. + pub fn finish(self, num_base: usize) -> ConstraintProgram { ConstraintProgram { nodes: self.nodes, dims: self.dims, roots: self.roots, + num_base, } } } diff --git a/crypto/stark/src/constraint_ir/interp.rs b/crypto/stark/src/constraint_ir/interp.rs index 62e502594..4df7f9008 100644 --- a/crypto/stark/src/constraint_ir/interp.rs +++ b/crypto/stark/src/constraint_ir/interp.rs @@ -5,12 +5,20 @@ //! `FieldElement` arithmetic so per-op results are bit-identical to the boxed //! constraint path. Mixed-dimension ops auto-embed the `D1` operand into `D3`, //! mirroring the field tower's `F: IsSubFieldOf` arithmetic. +//! +//! [`eval_program`] / [`eval_program_verifier`] are the full entry points, +//! matching `AIR::compute_transition_prover` / `AIR::compute_transition` +//! respectively. [`eval_program_base`] is the minimal Phase-0 entry point +//! (single root, main-only, base-field result) kept for the original spike +//! diff test. use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField as GoldilocksExtension; use math::field::goldilocks::GoldilocksField; use super::ir::{ConstraintProgram, Dim, Op}; +use crate::table::TableView; +use crate::traits::TransitionEvaluationContext; type Fp = FieldElement; type Fp3 = FieldElement; @@ -41,12 +49,26 @@ impl Value { } } -/// Evaluate the program's single root over a base-field main row. -/// -/// `main_row[col]` resolves `Var { main: true, col, .. }` leaves. The minimal -/// algebraic constraint set only reads main columns at offset 0, row 0 and -/// returns a base-field (`D1`) value, so this returns a `FieldElement`. -pub fn eval_program_base(prog: &ConstraintProgram, main_row: &[Fp]) -> Fp { +/// Shared forward pass: evaluate every node, then return the value array. +/// `resolve_var` resolves `Op::Var` leaves; `resolve_periodic` resolves +/// `Op::Periodic`; the rest of the uniforms are read directly from `inputs`- +/// agnostic closures so prover/verifier share this one walk. +#[allow(clippy::too_many_arguments)] +fn run( + prog: &ConstraintProgram, + resolve_var: FVar, + resolve_periodic: FPeriodic, + resolve_challenge: FChallenge, + resolve_alpha: FAlpha, + resolve_offset: FOffset, +) -> Vec +where + FVar: Fn(bool, u8, u8, u16) -> Value, + FPeriodic: Fn(u16) -> Value, + FChallenge: Fn(u16) -> Fp3, + FAlpha: Fn(u16) -> Fp3, + FOffset: Fn() -> Fp3, +{ let mut values: Vec = Vec::with_capacity(prog.nodes.len()); for (i, op) in prog.nodes.iter().enumerate() { @@ -55,11 +77,16 @@ pub fn eval_program_base(prog: &ConstraintProgram, main_row: &[Fp]) -> Fp { Op::Const3([c0, c1, c2]) => { Value::D3(Fp3::from_raw([Fp::from(c0), Fp::from(c1), Fp::from(c2)])) } - Op::Var { main, row, col, .. } => { - assert!(main, "aux leaves are not part of the minimal algebraic set"); - assert_eq!(row, 0, "minimal set reads row 0 only"); - Value::D1(main_row[col as usize]) - } + Op::Var { + main, + offset, + row, + col, + } => resolve_var(main, offset, row, col), + Op::Periodic { idx } => resolve_periodic(idx), + Op::RapChallenge { idx } => Value::D3(resolve_challenge(idx)), + Op::AlphaPow { idx } => Value::D3(resolve_alpha(idx)), + Op::TableOffset => Value::D3(resolve_offset()), Op::Add(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x + y, |x, y| x + y), Op::Sub(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x - y, |x, y| x - y), Op::Mul(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x * y, |x, y| x * y), @@ -73,8 +100,7 @@ pub fn eval_program_base(prog: &ConstraintProgram, main_row: &[Fp]) -> Fp { values.push(v); } - let root = prog.roots[0]; - values[root as usize].as_base() + values } /// Apply a binary op, auto-embedding to the extension field when the result @@ -95,3 +121,129 @@ fn binop( _ => Value::D3(ext_op(va.to_ext(), vb.to_ext())), } } + +/// Evaluate one constraint's root over a base-field main row. +/// +/// `main_row[col]` resolves `Var { main: true, col, .. }` leaves. The minimal +/// algebraic constraint set only reads main columns at offset 0, row 0 and +/// returns a base-field (`D1`) value, so this returns a `FieldElement`. +/// `constraint_idx` selects which root to read (a single-constraint capture +/// from the Phase-0 diff test always uses the constraint's own +/// `constraint_idx`, which `IrBuilder::emit` now indexes `roots` by). +/// +/// Kept for the Phase-0 diff test (`prover/src/tests/constraint_ir_tests.rs`); +/// [`eval_program`] is the full prover entry point. +pub fn eval_program_base(prog: &ConstraintProgram, constraint_idx: usize, main_row: &[Fp]) -> Fp { + let values = run( + prog, + |main, _offset, row, col| { + assert!(main, "aux leaves are not part of the minimal algebraic set"); + assert_eq!(row, 0, "minimal set reads row 0 only"); + Value::D1(main_row[col as usize]) + }, + |_idx| panic!("periodic leaves are not part of the minimal algebraic set"), + |_idx| panic!("challenge leaves are not part of the minimal algebraic set"), + |_idx| panic!("alpha_power leaves are not part of the minimal algebraic set"), + || panic!("table_offset leaves are not part of the minimal algebraic set"), + ); + let root = prog.roots[constraint_idx]; + values[root as usize].as_base() +} + +/// Full prover entry point: evaluate every constraint in `prog` against +/// `ctx` (must be [`TransitionEvaluationContext::Prover`]), writing base-field +/// (`D1`-rooted) constraints into `base_evals` and extension-field +/// (`D3`-rooted) constraints into `ext_evals[prog.num_base..]` — the same +/// contract as `AIR::compute_transition_prover`. +pub fn eval_program( + prog: &ConstraintProgram, + ctx: &TransitionEvaluationContext, + base_evals: &mut [FieldElement], + ext_evals: &mut [FieldElement], +) { + let TransitionEvaluationContext::Prover { + frame, + periodic_values, + rap_challenges, + logup_alpha_powers, + logup_table_offset, + .. + } = ctx + else { + unreachable!("eval_program called with a Verifier context"); + }; + + let values = run( + prog, + |main, offset, row, col| { + let step: &TableView = + frame.get_evaluation_step(offset as usize); + debug_assert_eq!(row, 0, "tables read row 0 of each frame step"); + if main { + Value::D1(*step.get_main_evaluation_element(0, col as usize)) + } else { + Value::D3(*step.get_aux_evaluation_element(0, col as usize)) + } + }, + |idx| Value::D1(periodic_values[idx as usize]), + |idx| rap_challenges[idx as usize], + |idx| logup_alpha_powers[idx as usize], + || *(*logup_table_offset), + ); + + for (c, &root) in prog.roots.iter().enumerate() { + let v = values[root as usize]; + if c < prog.num_base { + base_evals[c] = v.as_base(); + } else { + ext_evals[c] = v.to_ext(); + } + } +} + +/// Full verifier entry point: evaluate every constraint in `prog` against +/// `ctx` (must be [`TransitionEvaluationContext::Verifier`]) at the +/// out-of-domain point, writing every constraint (base or LogUp) into +/// `ext_evals` — the same contract as `AIR::compute_transition`. The verifier +/// frame holds only extension-field elements, so `D1`-rooted constraints are +/// embedded into `D3` on write (mirrors `evaluate(..).to_extension()` in +/// [`crate::constraints::transition::TransitionConstraintAdapter`]). +pub fn eval_program_verifier( + prog: &ConstraintProgram, + ctx: &TransitionEvaluationContext, + ext_evals: &mut [FieldElement], +) { + let TransitionEvaluationContext::Verifier { + frame, + periodic_values, + rap_challenges, + logup_alpha_powers, + logup_table_offset, + .. + } = ctx + else { + unreachable!("eval_program_verifier called with a Prover context"); + }; + + let values = run( + prog, + |main, offset, row, col| { + let step: &TableView = + frame.get_evaluation_step(offset as usize); + debug_assert_eq!(row, 0, "tables read row 0 of each frame step"); + if main { + Value::D3(*step.get_main_evaluation_element(0, col as usize)) + } else { + Value::D3(*step.get_aux_evaluation_element(0, col as usize)) + } + }, + |idx| Value::D3(periodic_values[idx as usize]), + |idx| rap_challenges[idx as usize], + |idx| logup_alpha_powers[idx as usize], + || *(*logup_table_offset), + ); + + for (c, &root) in prog.roots.iter().enumerate() { + ext_evals[c] = values[root as usize].to_ext(); + } +} diff --git a/crypto/stark/src/constraint_ir/ir.rs b/crypto/stark/src/constraint_ir/ir.rs index 8d0a3c449..5b6603fba 100644 --- a/crypto/stark/src/constraint_ir/ir.rs +++ b/crypto/stark/src/constraint_ir/ir.rs @@ -40,6 +40,15 @@ pub enum Op { /// Column index. col: u16, }, + /// A periodic column read: `periodic_values[idx]` at the current row (`D1`). + Periodic { idx: u16 }, + /// A LogUp RAP challenge: `rap_challenges[idx]` (`D3`, uniform per proof). + RapChallenge { idx: u16 }, + /// A precomputed LogUp alpha power: `logup_alpha_powers[idx]` (`D3`, uniform + /// per proof). + AlphaPow { idx: u16 }, + /// The LogUp table offset `L/N` (`D3`, uniform per proof). + TableOffset, /// `nodes[a] + nodes[b]`. Add(u32, u32), /// `nodes[a] - nodes[b]`. @@ -63,8 +72,13 @@ pub struct ConstraintProgram { pub nodes: Vec, /// Per-node result dimension, parallel to `nodes`. pub dims: Vec, - /// Per-constraint root node ids. + /// Per-constraint root node ids, indexed by `constraint_idx`. pub roots: Vec, + /// Number of constraints (a prefix of `roots`) that are base-field (`D1`) + /// rooted, matching `AIR::num_base_transition_constraints()`. The prover + /// interpreter writes these into `base_evals`; the rest (LogUp, always + /// `D3`) go into `ext_evals[num_base..]`. + pub num_base: usize, } impl ConstraintProgram { diff --git a/crypto/stark/src/constraint_ir/mod.rs b/crypto/stark/src/constraint_ir/mod.rs index a515ff177..f904d51f1 100644 --- a/crypto/stark/src/constraint_ir/mod.rs +++ b/crypto/stark/src/constraint_ir/mod.rs @@ -24,7 +24,7 @@ pub mod interp; pub mod ir; pub use builder::{Expr, IrBuilder}; -pub use interp::eval_program_base; +pub use interp::{eval_program, eval_program_base, eval_program_verifier}; pub use ir::{ConstraintProgram, Dim, Op}; /// A transition constraint that can record its algebra into an [`IrBuilder`]. diff --git a/crypto/stark/src/constraints/transition.rs b/crypto/stark/src/constraints/transition.rs index 1fe249c4c..4710d6d08 100644 --- a/crypto/stark/src/constraints/transition.rs +++ b/crypto/stark/src/constraints/transition.rs @@ -1,5 +1,6 @@ use core::ops::Div; +use crate::constraint_ir::IrBuilder; use crate::domain::Domain; use crate::traits::TransitionEvaluationContext; use math::field::element::FieldElement; @@ -20,6 +21,26 @@ where /// where N is the total number of transition constraints. fn constraint_idx(&self) -> usize; + /// Translate this constraint's algebra into [`IrBuilder`] nodes, finishing + /// with `builder.emit(self.constraint_idx(), root)`. Object-safe (the + /// builder is concrete, not generic), so it lives directly on this boxed + /// trait — see [`crate::constraint_ir::Capture`] for the user-facing + /// (non-boxed) counterpart that [`super::transition::TransitionConstraintAdapter`] + /// forwards to. + /// + /// Default panics: every production constraint must override this (via + /// `TransitionConstraintAdapter` + `Capture`, or directly for the LogUp + /// framework constraints). The default exists only so the many + /// `examples/` and test-only `TransitionConstraintEvaluator` impls (not + /// part of the IR migration) don't need a body. + fn capture(&self, _builder: &mut IrBuilder) { + unimplemented!( + "TransitionConstraintEvaluator::capture not implemented for this constraint; \ + it is not part of the constraint-ir migration (see crypto/stark/src/examples/ \ + or implement Capture for production constraints)" + ); + } + /// The function representing the evaluation of the constraint over elements /// of the trace table. /// @@ -377,10 +398,12 @@ where /// Wrap into a boxed `TransitionConstraintEvaluator` for the evaluator. /// /// The adapter auto-generates `evaluate_verifier()` and `evaluate_prover()` - /// from the generic `evaluate()`. + /// from the generic `evaluate()`, and forwards `capture()` to `Self`'s + /// `Capture` impl (required so every boxed constraint can be captured into + /// the IR — see `crypto/stark/src/constraint_ir/mod.rs`). fn boxed(self) -> Box> where - Self: Sized + 'static, + Self: Sized + crate::constraint_ir::Capture + 'static, { Box::new(TransitionConstraintAdapter(self)) } @@ -389,12 +412,14 @@ where /// Adapter: implements `TransitionConstraintEvaluator` for any `TransitionConstraint`. /// /// Auto-generates `evaluate_verifier()` (E×E path) and `evaluate_prover()` (F path) -/// from the user's generic `evaluate()`. +/// from the user's generic `evaluate()`, and forwards `capture()` to the wrapped +/// `T: Capture` (every production constraint implements `Capture` alongside +/// `evaluate`; see `crypto/stark/src/constraint_ir/mod.rs`). pub struct TransitionConstraintAdapter(pub T); impl TransitionConstraintEvaluator for TransitionConstraintAdapter where - T: TransitionConstraint + 'static, + T: TransitionConstraint + crate::constraint_ir::Capture + 'static, F: IsSubFieldOf + IsFFTField + Send + Sync, E: IsField + Send + Sync, { @@ -419,6 +444,9 @@ where fn periodic_exemptions_offset(&self) -> Option { self.0.periodic_exemptions_offset() } + fn capture(&self, builder: &mut IrBuilder) { + self.0.capture(builder); + } fn evaluate_verifier( &self, diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 5174bf66c..dbcceefef 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::marker::PhantomData; use crate::{ + constraint_ir::{Expr, IrBuilder}, constraints::{ boundary::{BoundaryConstraint, BoundaryConstraints}, transition::TransitionConstraintEvaluator, @@ -1708,6 +1709,303 @@ fn compute_fingerprint_from_step, B: IsField>( z - &linear_combination } +// ============================================================================= +// IR capture (constraint-ir): fingerprint / multiplicity / packing +// ============================================================================= +// +// These mirror `Multiplicity::evaluate_with`, `Packing::accumulate_fingerprint_with`, +// `BusValue::accumulate_fingerprint_from_step`, and `compute_fingerprint_from_step` +// above, but build `IrBuilder` nodes instead of evaluating `FieldElement`s. All +// take an explicit `offset` (frame step: 0 = current row, 1 = next row) because +// `LookupAccumulatedConstraint`'s absorbed interactions read the next row, +// whereas `LookupBatchedTermConstraint` reads the current row (see capture +// impls below). Packing shifts (8/16/24) are base-field constants — `const_base` +// reproduces `PackingShifts` exactly, so no separate uniform leaf is needed. +// +// Honesty note (matches the runtime body): `BusValue::Linear`'s data-dependent +// "skip the multiply when the row value is zero" optimization is **not** +// reproduced here — the IR is row-agnostic and always emits the multiply. This +// is value-preserving (adding `0·α` is a no-op) and only costs a few extra +// D1×D3 muls/row; see `thoughts/gpu-constraint-eval/roadmap.md`. + +/// Capture the multiplicity for an interaction, mirroring +/// [`Multiplicity::evaluate_with`] (via [`compute_multiplicity_from_step`]). +fn capture_multiplicity(b: &mut IrBuilder, multiplicity: &Multiplicity, offset: u8) -> Expr { + match multiplicity { + Multiplicity::One => b.one(), + Multiplicity::Column(col) => b.main(offset, *col), + Multiplicity::Sum(a, c) => { + let va = b.main(offset, *a); + let vc = b.main(offset, *c); + b.add(va, vc) + } + Multiplicity::Negated(col) => { + let one = b.one(); + let v = b.main(offset, *col); + b.sub(one, v) + } + Multiplicity::Diff(a, c) => { + let va = b.main(offset, *a); + let vc = b.main(offset, *c); + b.sub(va, vc) + } + Multiplicity::Sum3(a, c, d) => { + let va = b.main(offset, *a); + let vc = b.main(offset, *c); + let vd = b.main(offset, *d); + let r = b.add(va, vc); + b.add(r, vd) + } + Multiplicity::Linear(terms) => capture_linear_terms(b, terms, offset), + } +} + +/// Capture a slice of `LinearTerm`s as a sum, mirroring the `Multiplicity::Linear` +/// arm of [`Multiplicity::evaluate_with`] (`result = Σ terms`, starting at zero). +fn capture_linear_terms(b: &mut IrBuilder, terms: &[LinearTerm], offset: u8) -> Expr { + let mut result = b.const_base(0); + for term in terms { + match *term { + LinearTerm::Column { + coefficient, + column, + } => { + let col = b.main(offset, column); + let coeff = b.const_signed(coefficient); + let term_e = b.mul(col, coeff); + result = b.add(result, term_e); + } + LinearTerm::ColumnUnsigned { + coefficient, + column, + } => { + let col = b.main(offset, column); + let coeff = b.const_base(coefficient); + let term_e = b.mul(col, coeff); + result = b.add(result, term_e); + } + LinearTerm::Constant(value) => { + let c = b.const_signed(value); + result = b.add(result, c); + } + } + } + result +} + +/// Capture a `Packing`'s fingerprint contribution, mirroring +/// [`Packing::accumulate_fingerprint_with`]. `acc` is updated in place (like the +/// real `&mut acc`); `alpha_powers` is the captured `Expr` for each needed +/// `alpha_power` leaf, fetched lazily via `alpha_power_at`. Returns the number of +/// alpha powers consumed (`= packing.num_bus_elements()`). +fn capture_packing_fingerprint( + b: &mut IrBuilder, + packing: Packing, + start_col: usize, + offset: u8, + alpha_offset: usize, + acc: &mut Expr, +) -> usize { + let alpha = |i: usize, b: &mut IrBuilder| b.alpha_power(alpha_offset + i); + let col = |c: usize, b: &mut IrBuilder| b.main(offset, c); + + match packing { + Packing::Direct => { + let v = col(start_col, b); + let ap = alpha(0, b); + let t = b.mul(v, ap); + *acc = b.add(*acc, t); + 1 + } + Packing::Word2L => { + let shift_16 = b.const_base(SHIFT_16); + let c0 = col(start_col, b); + let c1 = col(start_col + 1, b); + let c1_shifted = b.mul(c1, shift_16); + let combined = b.add(c0, c1_shifted); + let ap = alpha(0, b); + let t = b.mul(combined, ap); + *acc = b.add(*acc, t); + 1 + } + Packing::Word4L => { + let shift_8 = b.const_base(SHIFT_8); + let shift_16 = b.const_base(SHIFT_16); + let shift_24 = b.const_base(SHIFT_8 * SHIFT_16); + let c0 = col(start_col, b); + let c1 = col(start_col + 1, b); + let c2 = col(start_col + 2, b); + let c3 = col(start_col + 3, b); + let c1s = b.mul(c1, shift_8); + let c2s = b.mul(c2, shift_16); + let c3s = b.mul(c3, shift_24); + let combined = b.add(c0, c1s); + let combined = b.add(combined, c2s); + let combined = b.add(combined, c3s); + let ap = alpha(0, b); + let t = b.mul(combined, ap); + *acc = b.add(*acc, t); + 1 + } + Packing::DWordWL => { + let c0 = col(start_col, b); + let ap0 = alpha(0, b); + let t0 = b.mul(c0, ap0); + *acc = b.add(*acc, t0); + let c1 = col(start_col + 1, b); + let ap1 = alpha(1, b); + let t1 = b.mul(c1, ap1); + *acc = b.add(*acc, t1); + 2 + } + Packing::DWordHHW => { + let c0 = col(start_col, b); + let ap0 = alpha(0, b); + let t0 = b.mul(c0, ap0); + *acc = b.add(*acc, t0); + let shift_16 = b.const_base(SHIFT_16); + let c1 = col(start_col + 1, b); + let c2 = col(start_col + 2, b); + let c2_shifted = b.mul(c2, shift_16); + let w = b.add(c1, c2_shifted); + let ap1 = alpha(1, b); + let t1 = b.mul(w, ap1); + *acc = b.add(*acc, t1); + 2 + } + Packing::DWordWHH => { + let shift_16 = b.const_base(SHIFT_16); + let c0 = col(start_col, b); + let c1 = col(start_col + 1, b); + let c1_shifted = b.mul(c1, shift_16); + let w = b.add(c0, c1_shifted); + let ap0 = alpha(0, b); + let t0 = b.mul(w, ap0); + *acc = b.add(*acc, t0); + let c2 = col(start_col + 2, b); + let ap1 = alpha(1, b); + let t1 = b.mul(c2, ap1); + *acc = b.add(*acc, t1); + 2 + } + Packing::DWordHL => { + let shift_16 = b.const_base(SHIFT_16); + let c0 = col(start_col, b); + let c1 = col(start_col + 1, b); + let c1_shifted = b.mul(c1, shift_16); + let w0 = b.add(c0, c1_shifted); + let ap0 = alpha(0, b); + let t0 = b.mul(w0, ap0); + *acc = b.add(*acc, t0); + let c2 = col(start_col + 2, b); + let c3 = col(start_col + 3, b); + let c3_shifted = b.mul(c3, shift_16); + let w1 = b.add(c2, c3_shifted); + let ap1 = alpha(1, b); + let t1 = b.mul(w1, ap1); + *acc = b.add(*acc, t1); + 2 + } + Packing::DWordBL => { + let shift_8 = b.const_base(SHIFT_8); + let shift_16 = b.const_base(SHIFT_16); + let shift_24 = b.const_base(SHIFT_8 * SHIFT_16); + let c0 = col(start_col, b); + let c1 = col(start_col + 1, b); + let c2 = col(start_col + 2, b); + let c3 = col(start_col + 3, b); + let c1s = b.mul(c1, shift_8); + let c2s = b.mul(c2, shift_16); + let c3s = b.mul(c3, shift_24); + let w0 = b.add(c0, c1s); + let w0 = b.add(w0, c2s); + let w0 = b.add(w0, c3s); + let ap0 = alpha(0, b); + let t0 = b.mul(w0, ap0); + *acc = b.add(*acc, t0); + let c4 = col(start_col + 4, b); + let c5 = col(start_col + 5, b); + let c6 = col(start_col + 6, b); + let c7 = col(start_col + 7, b); + let c5s = b.mul(c5, shift_8); + let c6s = b.mul(c6, shift_16); + let c7s = b.mul(c7, shift_24); + let w1 = b.add(c4, c5s); + let w1 = b.add(w1, c6s); + let w1 = b.add(w1, c7s); + let ap1 = alpha(1, b); + let t1 = b.mul(w1, ap1); + *acc = b.add(*acc, t1); + 2 + } + Packing::QuadHL => { + let shift_16 = b.const_base(SHIFT_16); + for i in 0..4 { + let c = start_col + i * 2; + let c0 = col(c, b); + let c1 = col(c + 1, b); + let c1_shifted = b.mul(c1, shift_16); + let w = b.add(c0, c1_shifted); + let ap = alpha(i, b); + let t = b.mul(w, ap); + *acc = b.add(*acc, t); + } + 4 + } + Packing::QuadWL => { + for i in 0..4 { + let v = col(start_col + i, b); + let ap = alpha(i, b); + let t = b.mul(v, ap); + *acc = b.add(*acc, t); + } + 4 + } + } +} + +/// Capture a `BusValue`'s fingerprint contribution, mirroring +/// [`BusValue::accumulate_fingerprint_from_step`]. Returns the number of alpha +/// powers consumed. +fn capture_busvalue_fingerprint( + b: &mut IrBuilder, + bv: &BusValue, + offset: u8, + alpha_offset: usize, + acc: &mut Expr, +) -> usize { + match bv { + BusValue::Packed { + start_column, + packing, + } => capture_packing_fingerprint(b, *packing, *start_column, offset, alpha_offset, acc), + BusValue::Linear(terms) => { + // Mirrors the runtime zero-skip's *value* (not its data-dependent + // skip, see the module-level honesty note): the IR always emits the + // multiply. + let result = capture_linear_terms(b, terms, offset); + let ap = b.alpha_power(alpha_offset); + let t = b.mul(result, ap); + *acc = b.add(*acc, t); + 1 + } + } +} + +/// Capture an interaction's fingerprint, mirroring [`compute_fingerprint_from_step`]: +/// `z - (bus_id + α·v[0] + α²·v[1] + ...)`. +fn capture_fingerprint(b: &mut IrBuilder, interaction: &BusInteraction, offset: u8) -> Expr { + let z = b.challenge(0); + // α⁰ = 1: the bus-id term needs no multiply — a base constant added directly, + // matching `FieldElement::::from(interaction.bus_id)`. + let mut lc = b.const_base(interaction.bus_id); + let mut alpha_idx = 1; + for bv in &interaction.values { + alpha_idx += capture_busvalue_fingerprint(b, bv, offset, alpha_idx, &mut lc); + } + b.sub(z, lc) +} + /// Constraint for a batched pair of interactions sharing one aux column. /// /// Verifies: `c = m_a/fp_a + m_b/fp_b` where signs are baked into m_a, m_b. @@ -1828,6 +2126,37 @@ where *eval = res; } } + + fn capture(&self, b: &mut IrBuilder) { + // c * fp_a * fp_b - sign_a * m_a * fp_b - sign_b * m_b * fp_a + // Mirrors `evaluate_batched_term_constraint` above: `is_sender` is a + // compile-time bool, resolved here as `add` vs `neg` instead of an E×E + // sign multiply (same optimization as the runtime body). + let c = b.aux(0, self.term_column_idx); + let m_a = capture_multiplicity(b, &self.interaction_a.multiplicity, 0); + let m_b = capture_multiplicity(b, &self.interaction_b.multiplicity, 0); + let fp_a = capture_fingerprint(b, &self.interaction_a, 0); + let fp_b = capture_fingerprint(b, &self.interaction_b, 0); + + let term_a = b.mul(m_a, fp_b); + let term_a = if self.interaction_a.is_sender { + term_a + } else { + b.neg(term_a) + }; + let term_b = b.mul(m_b, fp_a); + let term_b = if self.interaction_b.is_sender { + term_b + } else { + b.neg(term_b) + }; + + let main = b.mul(c, fp_a); + let main = b.mul(main, fp_b); + let root = b.sub(main, term_a); + let root = b.sub(root, term_b); + b.emit(self.constraint_idx, root); + } } /// Constraint for the accumulated column with absorbed interactions. @@ -2003,4 +2332,69 @@ where *eval = res; } } + + fn capture(&self, b: &mut IrBuilder) { + // Mirrors `evaluate_accumulated_constraint` above. `acc_curr` reads the + // current row (offset 0); `acc_next`/`terms_sum`/the absorbed + // fingerprints+multiplicities all read the *next* row (offset 1) — + // this is the one constraint in the IR migration that needs offset 1. + let acc_curr = b.aux(0, self.acc_column_idx); + let acc_next = b.aux(1, self.acc_column_idx); + + // Sum of all committed term columns at the next step. + let mut terms_sum = b.const_base(0); + for i in 0..self.num_term_columns { + let t = b.aux(1, i); + terms_sum = b.add(terms_sum, t); + } + + // delta = acc_next - acc_curr - terms_sum + L/N + let offset = b.table_offset(); + let delta = b.sub(acc_next, acc_curr); + let delta = b.sub(delta, terms_sum); + let delta = b.add(delta, offset); + + let root = match self.absorbed.len() { + 1 => { + // delta * f - sign * m + let m = capture_multiplicity(b, &self.absorbed[0].multiplicity, 1); + let f = capture_fingerprint(b, &self.absorbed[0], 1); + let mt = if self.absorbed[0].is_sender { + m + } else { + b.neg(m) + }; + let lhs = b.mul(delta, f); + b.sub(lhs, mt) + } + 2 => { + // delta * f1 * f2 - sign1*m1*f2 - sign2*m2*f1 + let m1 = capture_multiplicity(b, &self.absorbed[0].multiplicity, 1); + let m2 = capture_multiplicity(b, &self.absorbed[1].multiplicity, 1); + let f1 = capture_fingerprint(b, &self.absorbed[0], 1); + let f2 = capture_fingerprint(b, &self.absorbed[1], 1); + + let term1 = b.mul(m1, f2); + let term1 = if self.absorbed[0].is_sender { + term1 + } else { + b.neg(term1) + }; + let term2 = b.mul(m2, f1); + let term2 = if self.absorbed[1].is_sender { + term2 + } else { + b.neg(term2) + }; + + let lhs = b.mul(delta, f1); + let lhs = b.mul(lhs, f2); + let r = b.sub(lhs, term1); + b.sub(r, term2) + } + _ => unreachable!("absorbed must contain 1 or 2 interactions"), + }; + + b.emit(self.constraint_idx, root); + } } diff --git a/crypto/stark/src/traits.rs b/crypto/stark/src/traits.rs index 06465b659..489546cee 100644 --- a/crypto/stark/src/traits.rs +++ b/crypto/stark/src/traits.rs @@ -10,6 +10,7 @@ use math::{ }; use crate::{ + constraint_ir::{ConstraintProgram, IrBuilder}, constraints::transition::TransitionConstraintEvaluator, domain::Domain, lookup::{BusPublicInputs, PackingShifts}, @@ -315,6 +316,25 @@ pub trait AIR: Send + Sync { &self, ) -> &Vec>>; + /// Capture every transition constraint into one flat [`ConstraintProgram`]. + /// + /// Calls `capture` on each boxed constraint (object-safe, see + /// [`TransitionConstraintEvaluator::capture`]) into a single [`IrBuilder`], + /// so `roots[c]` ends up indexed by `constraint_idx()` exactly like + /// `transition_constraints()` itself. `num_base` is + /// `num_base_transition_constraints()`, matching the prover's base/ext + /// split (`compute_transition_prover`). + /// + /// Not cached here — callers (e.g. `ConstraintEvaluator::new`) that need + /// the program on every prove/verify should cache it once. + fn constraint_program(&self) -> ConstraintProgram { + let mut builder = IrBuilder::new(); + for c in self.transition_constraints() { + c.capture(&mut builder); + } + builder.finish(self.num_base_transition_constraints()) + } + /// Compute zerofier evaluations as deduplicated groups with index mapping. /// /// Each unique zerofier (keyed by period/offset/exemption parameters) is diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index 1c811471b..c2d82df41 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -17,7 +17,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; -use stark::constraint_ir::{Capture, IrBuilder}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::table::TableView; @@ -72,6 +72,21 @@ where + step.get_main_evaluation_element(0, hi_col) * shift_16 } +/// Capture `cast(res, DWordWL)` low/high words, mirroring [`res_word`]. +#[inline] +fn capture_res_word(b: &mut IrBuilder, high: bool) -> Expr { + let (lo_col, hi_col) = if high { + (cols::RES_2, cols::RES_3) + } else { + (cols::RES_0, cols::RES_1) + }; + let lo = b.main(0, lo_col); + let hi = b.main(0, hi_col); + let shift_16 = b.const_base(SHIFT_16); + let hi_shifted = b.mul(hi, shift_16); + b.add(lo, hi_shifted) +} + // ========================================================================= // decode group: word_instr mutex // ========================================================================= @@ -163,6 +178,22 @@ impl TransitionConstraint for Arg2Exclusiv } } +impl Capture for Arg2ExclusiveConstraint { + fn capture(&self, b: &mut IrBuilder) { + // (1 - memory - branch) * rr2 * imm + let one = b.one(); + let memory = b.main(0, cols::MEMORY); + let branch = b.main(0, cols::BRANCH); + let rr2 = b.main(0, cols::READ_REGISTER2); + let imm = b.main(0, self.imm_col); + let coeff = b.sub(one, memory); + let coeff = b.sub(coeff, branch); + let coeff_rr2 = b.mul(coeff, rr2); + let root = b.mul(coeff_rr2, imm); + b.emit(self.constraint_idx, root); + } +} + /// `IS_BIT` on non-MEMORY rows: `(1 - MEMORY) · mem_flags · (1 - mem_flags) = 0`. /// On non-memory rows `mem_flags` carries only the JALR bit, so it must be 0/1. /// A spec defense-in-depth assumption (the DECODE lookup already enforces it). @@ -197,6 +228,21 @@ impl TransitionConstraint for MemFlagsBitC } } +impl Capture for MemFlagsBitConstraint { + fn capture(&self, b: &mut IrBuilder) { + // (1 - memory) * mem_flags * (1 - mem_flags) + let one = b.one(); + let memory = b.main(0, cols::MEMORY); + let mem_flags = b.main(0, cols::MEM_FLAGS); + let one_minus_memory = b.sub(one, memory); + let one2 = b.one(); + let one_minus_mem_flags = b.sub(one2, mem_flags); + let lhs = b.mul(one_minus_memory, mem_flags); + let root = b.mul(lhs, one_minus_mem_flags); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // mem group: register zero-forcing // ========================================================================= @@ -240,6 +286,18 @@ impl TransitionConstraint for RegNotReadIs } } +impl Capture for RegNotReadIsZeroConstraint { + fn capture(&self, b: &mut IrBuilder) { + // (1 - flag) * value + let one = b.one(); + let flag = b.main(0, self.flag_col); + let value = b.main(0, self.value_col); + let one_minus_flag = b.sub(one, flag); + let root = b.mul(one_minus_flag, value); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // alu group: arg2 multiplex // ========================================================================= @@ -314,6 +372,38 @@ impl TransitionConstraint for Arg2Constrai } } +impl Capture for Arg2Constraint { + fn capture(&self, b: &mut IrBuilder) { + let (arg2_col, imm_col, rv2_col) = if self.word_idx == 0 { + (cols::ARG2_0, cols::IMM_0, cols::RV2_0) + } else { + (cols::ARG2_1, cols::IMM_1, cols::RV2_1) + }; + + let one = b.one(); + let arg2 = b.main(0, arg2_col); + let imm = b.main(0, imm_col); + let rv2 = b.main(0, rv2_col); + let memory = b.main(0, cols::MEMORY); + let branch = b.main(0, cols::BRANCH); + + // MEMORY * imm + let memory_imm = b.mul(memory, imm); + // BRANCH * rv2 + let branch_rv2 = b.mul(branch, rv2); + // (1 - MEMORY - BRANCH) * (rv2 + imm) + let coeff = b.sub(one, memory); + let coeff = b.sub(coeff, branch); + let rv2_imm = b.add(rv2, imm); + let last_term = b.mul(coeff, rv2_imm); + + let expected = b.add(memory_imm, branch_rv2); + let expected = b.add(expected, last_term); + let root = b.sub(arg2, expected); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // mem group: ¬MEMORY ∧ ¬JALR ⇒ rvd = cast(res, WL) // ========================================================================= @@ -365,6 +455,25 @@ impl TransitionConstraint for RvdEqResCons } } +impl Capture for RvdEqResConstraint { + fn capture(&self, b: &mut IrBuilder) { + let high = self.word_idx == 1; + let rvd_col = if high { cols::RVD_1 } else { cols::RVD_0 }; + let one = b.one(); + let memory = b.main(0, cols::MEMORY); + let branch = b.main(0, cols::BRANCH); + let rvd = b.main(0, rvd_col); + let res_w = capture_res_word(b, high); + + // (1 - memory - branch) * (rvd - res_w) + let coeff = b.sub(one, memory); + let coeff = b.sub(coeff, branch); + let diff = b.sub(rvd, res_w); + let root = b.mul(coeff, diff); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // branch group: BRANCH ⇒ rvd = pc + instruction_length // ========================================================================= @@ -428,6 +537,29 @@ impl BranchRvdConstraint { let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); (pc_hi + carry_0 - rvd_hi) * inv_2_32 } + + /// Capture carry_0, mirroring [`Self::compute_carry_0`]. + fn capture_carry_0(&self, b: &mut IrBuilder) -> Expr { + let pc_lo = b.main(0, cols::PC_0); + let rvd_lo = b.main(0, cols::RVD_0); + let half_len = b.main(0, cols::HALF_INSTRUCTION_LENGTH); + let instr_len = b.add(half_len, half_len); // real byte length = 2 * half + let inv_2_32 = b.const_base(super::templates::INV_SHIFT_32); + let s = b.add(pc_lo, instr_len); + let s = b.sub(s, rvd_lo); + b.mul(s, inv_2_32) + } + + /// Capture carry_1, mirroring [`Self::compute_carry_1`]. + fn capture_carry_1(&self, b: &mut IrBuilder) -> Expr { + let pc_hi = b.main(0, cols::PC_1); + let rvd_hi = b.main(0, cols::RVD_1); + let carry_0 = self.capture_carry_0(b); + let inv_2_32 = b.const_base(super::templates::INV_SHIFT_32); + let s = b.add(pc_hi, carry_0); + let s = b.sub(s, rvd_hi); + b.mul(s, inv_2_32) + } } impl TransitionConstraint for BranchRvdConstraint { @@ -456,6 +588,23 @@ impl TransitionConstraint for BranchRvdCon } } +impl Capture for BranchRvdConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let branch = b.main(0, cols::BRANCH); + let carry = match self.carry_idx { + 0 => self.capture_carry_0(b), + 1 => self.capture_carry_1(b), + _ => unreachable!("carry_idx validated <= 1 at construction"), + }; + // branch * carry * (1 - carry) + let one_minus_carry = b.sub(one, carry); + let branch_carry = b.mul(branch, carry); + let root = b.mul(branch_carry, one_minus_carry); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // branch group: branch_cond // ========================================================================= @@ -499,6 +648,26 @@ impl TransitionConstraint for BranchCondCo } } +impl Capture for BranchCondConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let branch = b.main(0, cols::BRANCH); + let jalr = b.main(0, cols::MEM_FLAGS); + let res0 = b.main(0, cols::RES_0); + let branch_cond = b.main(0, cols::BRANCH_COND); + + // branch*jalr + branch*(1-jalr)*res0 + let branch_jalr = b.mul(branch, jalr); + let one_minus_jalr = b.sub(one, jalr); + let branch_one_minus_jalr = b.mul(branch, one_minus_jalr); + let second_term = b.mul(branch_one_minus_jalr, res0); + let expected = b.add(branch_jalr, second_term); + + let root = b.sub(branch_cond, expected); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // branch group: next_pc = pc + instruction_length (when not branching) // ========================================================================= @@ -552,6 +721,29 @@ impl NextPcAddConstraint { let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); (pc_hi + carry_0 - next_pc_hi) * inv_2_32 } + + /// Capture carry_0, mirroring [`Self::compute_carry_0`]. + fn capture_carry_0(&self, b: &mut IrBuilder) -> Expr { + let pc_lo = b.main(0, cols::PC_0); + let next_pc_lo = b.main(0, cols::NEXT_PC_0); + let half_len = b.main(0, cols::HALF_INSTRUCTION_LENGTH); + let instr_len = b.add(half_len, half_len); // real byte length = 2 * half + let inv_2_32 = b.const_base(super::templates::INV_SHIFT_32); + let s = b.add(pc_lo, instr_len); + let s = b.sub(s, next_pc_lo); + b.mul(s, inv_2_32) + } + + /// Capture carry_1, mirroring [`Self::compute_carry_1`]. + fn capture_carry_1(&self, b: &mut IrBuilder) -> Expr { + let pc_hi = b.main(0, cols::PC_1); + let next_pc_hi = b.main(0, cols::NEXT_PC_1); + let carry_0 = self.capture_carry_0(b); + let inv_2_32 = b.const_base(super::templates::INV_SHIFT_32); + let s = b.add(pc_hi, carry_0); + let s = b.sub(s, next_pc_hi); + b.mul(s, inv_2_32) + } } impl TransitionConstraint for NextPcAddConstraint { @@ -582,6 +774,24 @@ impl TransitionConstraint for NextPcAddCon } } +impl Capture for NextPcAddConstraint { + fn capture(&self, b: &mut IrBuilder) { + let branch_cond = b.main(0, cols::BRANCH_COND); + let one = b.one(); + let not_branch = b.sub(one, branch_cond); + let carry = match self.carry_idx { + 0 => self.capture_carry_0(b), + 1 => self.capture_carry_1(b), + _ => unreachable!("carry_idx validated <= 1 at construction"), + }; + let one2 = b.one(); + let one_minus_carry = b.sub(one2, carry); + let not_branch_carry = b.mul(not_branch, carry); + let root = b.mul(not_branch_carry, one_minus_carry); + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // alu group: ADD / SUB fast-path templates // ========================================================================= diff --git a/prover/src/tables/branch.rs b/prover/src/tables/branch.rs index 9443a81a1..650f6e895 100644 --- a/prover/src/tables/branch.rs +++ b/prover/src/tables/branch.rs @@ -28,6 +28,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::TransitionConstraint; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -514,6 +515,59 @@ impl BranchConstraint { } } } + + /// Capture virtual next_pc_unmasked as DWordWL, mirroring + /// [`Self::compute_next_pc_unmasked`]. + fn capture_next_pc_unmasked(b: &mut IrBuilder) -> (Expr, Expr) { + let unmasked_low_byte = b.main(0, cols::UNMASKED_LOW_BYTE); + let next_pc_low_1 = b.main(0, cols::NEXT_PC_LOW_1); + let next_pc_high_0 = b.main(0, cols::NEXT_PC_HIGH_0); + let next_pc_high_1 = b.main(0, cols::NEXT_PC_HIGH_1); + let next_pc_high_2 = b.main(0, cols::NEXT_PC_HIGH_2); + + let shift_8 = b.const_base(SHIFT_8); + let shift_16 = b.const_base(SHIFT_16); + + // unmasked_low_byte + next_pc_low_1 * shift_8 + next_pc_high_0 * shift_16 + let t1 = b.mul(next_pc_low_1, shift_8); + let t2 = b.mul(next_pc_high_0, shift_16); + let unmasked_0 = b.add(unmasked_low_byte, t1); + let unmasked_0 = b.add(unmasked_0, t2); + + // next_pc_high_1 + next_pc_high_2 * shift_16 + let t3 = b.mul(next_pc_high_2, shift_16); + let unmasked_1 = b.add(next_pc_high_1, t3); + + (unmasked_0, unmasked_1) + } + + /// Capture carry_0 for a given base column, mirroring [`Self::compute_carry_0_for`]. + fn capture_carry_0_for(base_col_0: usize, b: &mut IrBuilder) -> Expr { + let base_0 = b.main(0, base_col_0); + let offset_0 = b.main(0, cols::OFFSET_0); + let (unmasked_0, _) = Self::capture_next_pc_unmasked(b); + + let inv_2_32 = b.const_base(crate::constraints::templates::INV_SHIFT_32); + // (base_0 + offset_0 - unmasked_0) * inv_2_32 + let s = b.add(base_0, offset_0); + let s = b.sub(s, unmasked_0); + b.mul(s, inv_2_32) + } + + /// Capture carry_1 for a given base column pair, mirroring [`Self::compute_carry_1_for`]. + fn capture_carry_1_for(base_col_0: usize, base_col_1: usize, b: &mut IrBuilder) -> Expr { + let base_1 = b.main(0, base_col_1); + let offset_1 = b.main(0, cols::OFFSET_1); + let carry_0 = Self::capture_carry_0_for(base_col_0, b); + let (_, unmasked_1) = Self::capture_next_pc_unmasked(b); + + let inv_2_32 = b.const_base(crate::constraints::templates::INV_SHIFT_32); + // (base_1 + offset_1 + carry_0 - unmasked_1) * inv_2_32 + let s = b.add(base_1, offset_1); + let s = b.add(s, carry_0); + let s = b.sub(s, unmasked_1); + b.mul(s, inv_2_32) + } } impl TransitionConstraint for BranchConstraint { @@ -539,6 +593,53 @@ impl TransitionConstraint for BranchConstr } } +impl Capture for BranchConstraint { + fn capture(&self, b: &mut IrBuilder) { + let jalr = b.main(0, cols::JALR); + let one = b.one(); + + let root = match self.kind { + BranchConstraintKind::JalrIsBit => { + // jalr * (1 - jalr) + let one_minus_jalr = b.sub(one, jalr); + b.mul(jalr, one_minus_jalr) + } + BranchConstraintKind::PcCarry0IsBit => { + // (1 - jalr) * c * (1 - c) + let cond = b.sub(one, jalr); + let c = Self::capture_carry_0_for(cols::PC_0, b); + let one2 = b.one(); + let one_minus_c = b.sub(one2, c); + let cond_c = b.mul(cond, c); + b.mul(cond_c, one_minus_c) + } + BranchConstraintKind::PcCarry1IsBit => { + let cond = b.sub(one, jalr); + let c = Self::capture_carry_1_for(cols::PC_0, cols::PC_1, b); + let one2 = b.one(); + let one_minus_c = b.sub(one2, c); + let cond_c = b.mul(cond, c); + b.mul(cond_c, one_minus_c) + } + BranchConstraintKind::RegCarry0IsBit => { + // cond = jalr + let c = Self::capture_carry_0_for(cols::REGISTER_0, b); + let one_minus_c = b.sub(one, c); + let cond_c = b.mul(jalr, c); + b.mul(cond_c, one_minus_c) + } + BranchConstraintKind::RegCarry1IsBit => { + let c = Self::capture_carry_1_for(cols::REGISTER_0, cols::REGISTER_1, b); + let one_minus_c = b.sub(one, c); + let cond_c = b.mul(jalr, c); + b.mul(cond_c, one_minus_c) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the BRANCH table. /// /// Returns 5 constraints (two conditional ADD templates × 2 carries each, plus diff --git a/prover/src/tables/commit.rs b/prover/src/tables/commit.rs index c1663711e..442557961 100644 --- a/prover/src/tables/commit.rs +++ b/prover/src/tables/commit.rs @@ -45,6 +45,7 @@ //! use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -851,3 +852,23 @@ impl TransitionConstraint for CommitConstr self.compute(step) } } + +impl Capture for CommitConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let root = match self.kind { + CommitConstraintKind::FirstOrEndImpliesMu => { + let first = b.main(0, cols::FIRST); + let end = b.main(0, cols::END); + let mu = b.main(0, cols::MU); + // (first + end) * (1 - mu) + let sum = b.add(first, end); + let one_minus_mu = b.sub(one, mu); + b.mul(sum, one_minus_mu) + } + }; + + b.emit(self.constraint_idx, root); + } +} diff --git a/prover/src/tables/cpu32.rs b/prover/src/tables/cpu32.rs index d7dbd5d6f..4727cca9b 100644 --- a/prover/src/tables/cpu32.rs +++ b/prover/src/tables/cpu32.rs @@ -19,6 +19,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -710,6 +711,101 @@ impl TransitionConstraint for Cpu32Constra } } +impl Capture for Cpu32Constraint { + fn capture(&self, b: &mut IrBuilder) { + let shift16 = b.const_base(SHIFT_16); + let hi_fill = b.const_base(HI_FILL); + let one = b.one(); + + let root = match self.kind { + Cpu32ConstraintKind::Arg1Lo => { + // arg1[0] - rv1[0] - shift16*rv1[1] + let arg1_0 = b.main(0, cols::ARG1_0); + let rv1_0 = b.main(0, cols::RV1_0); + let rv1_1 = b.main(0, cols::RV1_1); + let shifted = b.mul(shift16, rv1_1); + let r = b.sub(arg1_0, rv1_0); + b.sub(r, shifted) + } + Cpu32ConstraintKind::Arg1Hi => { + // arg1[1] - hi_fill*rv1_sign + let arg1_1 = b.main(0, cols::ARG1_1); + let rv1_sign = b.main(0, cols::RV1_SIGN); + let fill = b.mul(hi_fill, rv1_sign); + b.sub(arg1_1, fill) + } + Cpu32ConstraintKind::Arg2Lo => { + // arg2[0] - rv2[0] - shift16*rv2[1] - imm[0] + let arg2_0 = b.main(0, cols::ARG2_0); + let rv2_0 = b.main(0, cols::RV2_0); + let rv2_1 = b.main(0, cols::RV2_1); + let imm_0 = b.main(0, cols::IMM_0); + let shifted = b.mul(shift16, rv2_1); + let r = b.sub(arg2_0, rv2_0); + let r = b.sub(r, shifted); + b.sub(r, imm_0) + } + Cpu32ConstraintKind::Arg2Hi => { + // arg2[1] - hi_fill*rv2_sign - imm[1] + let arg2_1 = b.main(0, cols::ARG2_1); + let rv2_sign = b.main(0, cols::RV2_SIGN); + let imm_1 = b.main(0, cols::IMM_1); + let fill = b.mul(hi_fill, rv2_sign); + let r = b.sub(arg2_1, fill); + b.sub(r, imm_1) + } + Cpu32ConstraintKind::RvdLo => { + // rvd[0] - res[0] - shift16*res[1] + let rvd_0 = b.main(0, cols::RVD_0); + let res_0 = b.main(0, cols::RES_0); + let res_1 = b.main(0, cols::RES_1); + let shifted = b.mul(shift16, res_1); + let r = b.sub(rvd_0, res_0); + b.sub(r, shifted) + } + Cpu32ConstraintKind::RvdHi => { + // rvd[1] - hi_fill*res_sign + let rvd_1 = b.main(0, cols::RVD_1); + let res_sign = b.main(0, cols::RES_SIGN); + let fill = b.mul(hi_fill, res_sign); + b.sub(rvd_1, fill) + } + Cpu32ConstraintKind::RegZero { + read_col, + value_col, + } => { + // (1 - read) * value + let read = b.main(0, read_col); + let value = b.main(0, value_col); + let one_minus_read = b.sub(one, read); + b.mul(one_minus_read, value) + } + Cpu32ConstraintKind::Arg2Exclusive { imm_col } => { + // read_register2 * imm + let rr2 = b.main(0, cols::READ_REGISTER2); + let imm = b.main(0, imm_col); + b.mul(rr2, imm) + } + Cpu32ConstraintKind::FlagImpliesMu { flag_col } => { + // (1 - mu) * flag + let mu = b.main(0, cols::MU); + let flag = b.main(0, flag_col); + let one_minus_mu = b.sub(one, mu); + b.mul(one_minus_mu, flag) + } + Cpu32ConstraintKind::SignZeroWhenUnsigned { sign_col } => { + // (1 - signed) * sign + let signed = b.main(0, cols::SIGNED); + let sign = b.main(0, sign_col); + let one_minus_signed = b.sub(one, signed); + b.mul(one_minus_signed, sign) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all transition constraints for the CPU32 table: /// `IS_BIT` on the flag columns, the `ADD`/`SUB` fast-path carries, the /// register-zero checks, and the sign-extension `ext` arithmetic. diff --git a/prover/src/tables/dvrm.rs b/prover/src/tables/dvrm.rs index 3da78dff5..d3b1808a9 100644 --- a/prover/src/tables/dvrm.rs +++ b/prover/src/tables/dvrm.rs @@ -31,6 +31,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::TransitionConstraint; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -1214,6 +1215,76 @@ impl DvrmConstraint { ext_word, ] } + + /// Capture carry[i], mirroring [`Self::compute_carry`]. + fn capture_carry(&self, i: usize, b: &mut IrBuilder) -> Expr { + let shift_16 = b.const_base(SHIFT_16); + let inv_2_32 = b.const_base(crate::constraints::templates::INV_SHIFT_32); + let sign_fill = b.const_base(SIGN_FILL); + + let n: [Expr; 4] = [ + b.main(0, cols::N_0), + b.main(0, cols::N_1), + b.main(0, cols::N_2), + b.main(0, cols::N_3), + ]; + let nsr: [Expr; 4] = [ + b.main(0, cols::N_SUB_R_0), + b.main(0, cols::N_SUB_R_1), + b.main(0, cols::N_SUB_R_2), + b.main(0, cols::N_SUB_R_3), + ]; + let r: [Expr; 4] = [ + b.main(0, cols::R_0), + b.main(0, cols::R_1), + b.main(0, cols::R_2), + b.main(0, cols::R_3), + ]; + + let sign_n = b.main(0, cols::SIGN_N); + let sign_r = b.main(0, cols::SIGN_R); + let sign_nsr = b.main(0, cols::SIGN_N_SUB_R); + + let ext_n = Self::capture_extended_quad(&n, sign_n, shift_16, sign_fill, b); + let ext_r = Self::capture_extended_quad(&r, sign_r, shift_16, sign_fill, b); + let ext_nsr = Self::capture_extended_quad(&nsr, sign_nsr, shift_16, sign_fill, b); + + if i == 0 { + // (ext_nsr[0] + ext_r[0] - ext_n[0]) * inv_2_32 + let s = b.add(ext_nsr[0], ext_r[0]); + let s = b.sub(s, ext_n[0]); + b.mul(s, inv_2_32) + } else { + let prev_carry = self.capture_carry(i - 1, b); + // (ext_nsr[i] + ext_r[i] + prev_carry - ext_n[i]) * inv_2_32 + let s = b.add(ext_nsr[i], ext_r[i]); + let s = b.add(s, prev_carry); + let s = b.sub(s, ext_n[i]); + b.mul(s, inv_2_32) + } + } + + /// Capture the sign-extended QuadWL representation, mirroring + /// [`Self::build_extended_quad`]. + fn capture_extended_quad( + halfwords: &[Expr; 4], + sign: Expr, + shift_16: Expr, + sign_fill: Expr, + b: &mut IrBuilder, + ) -> [Expr; 4] { + // ext_word = sign*sign_fill + sign*sign_fill*shift_16 + let sign_fill_term = b.mul(sign, sign_fill); + let sign_fill_shifted = b.mul(sign_fill_term, shift_16); + let ext_word = b.add(sign_fill_term, sign_fill_shifted); + + let h1_shifted = b.mul(halfwords[1], shift_16); + let w0 = b.add(halfwords[0], h1_shifted); + let h3_shifted = b.mul(halfwords[3], shift_16); + let w1 = b.add(halfwords[2], h3_shifted); + + [w0, w1, ext_word, ext_word] + } } impl TransitionConstraint for DvrmConstraint { @@ -1246,6 +1317,136 @@ impl TransitionConstraint for DvrmConstrai } } +impl Capture for DvrmConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let root = match self.kind { + DvrmConstraintKind::SignedIsBit => { + // signed * (1 - signed) + let signed = b.main(0, cols::SIGNED); + let one_minus_signed = b.sub(one, signed); + b.mul(signed, one_minus_signed) + } + DvrmConstraintKind::RemainderSignMatchesNumerator => { + // (r[0]+r[1]+r[2]+r[3]) * (sign_r - sign_n) + let r0 = b.main(0, cols::R_0); + let r1 = b.main(0, cols::R_1); + let r2 = b.main(0, cols::R_2); + let r3 = b.main(0, cols::R_3); + let sign_r = b.main(0, cols::SIGN_R); + let sign_n = b.main(0, cols::SIGN_N); + let r_sum = b.add(r0, r1); + let r_sum = b.add(r_sum, r2); + let r_sum = b.add(r_sum, r3); + let sign_diff = b.sub(sign_r, sign_n); + b.mul(r_sum, sign_diff) + } + DvrmConstraintKind::AbsRFormula(i) => { + // (1-sign_r) * (abs_r[i] - (r::DWordWL)[i]) + let sign_r = b.main(0, cols::SIGN_R); + let abs_r_col = if i == 0 { cols::ABS_R_0 } else { cols::ABS_R_1 }; + let abs_r = b.main(0, abs_r_col); + let shift_16 = b.const_base(SHIFT_16); + let r_wl = if i == 0 { + let r0 = b.main(0, cols::R_0); + let r1 = b.main(0, cols::R_1); + let r1_shifted = b.mul(r1, shift_16); + b.add(r0, r1_shifted) + } else { + let r2 = b.main(0, cols::R_2); + let r3 = b.main(0, cols::R_3); + let r3_shifted = b.mul(r3, shift_16); + b.add(r2, r3_shifted) + }; + let one_minus_sign_r = b.sub(one, sign_r); + let diff = b.sub(abs_r, r_wl); + b.mul(one_minus_sign_r, diff) + } + DvrmConstraintKind::AbsDFormula(i) => { + // (1-sign_d) * (abs_d[i] - (d::DWordWL)[i]) + let sign_d = b.main(0, cols::SIGN_D); + let abs_d_col = if i == 0 { cols::ABS_D_0 } else { cols::ABS_D_1 }; + let abs_d = b.main(0, abs_d_col); + let shift_16 = b.const_base(SHIFT_16); + let d_wl = if i == 0 { + let d0 = b.main(0, cols::D_0); + let d1 = b.main(0, cols::D_1); + let d1_shifted = b.mul(d1, shift_16); + b.add(d0, d1_shifted) + } else { + let d2 = b.main(0, cols::D_2); + let d3 = b.main(0, cols::D_3); + let d3_shifted = b.mul(d3, shift_16); + b.add(d2, d3_shifted) + }; + let one_minus_sign_d = b.sub(one, sign_d); + let diff = b.sub(abs_d, d_wl); + b.mul(one_minus_sign_d, diff) + } + DvrmConstraintKind::SignQFormula => { + // signed * (1-overflow) - sign_q + let signed = b.main(0, cols::SIGNED); + let overflow = b.main(0, cols::OVERFLOW); + let sign_q = b.main(0, cols::SIGN_Q); + let one_minus_overflow = b.sub(one, overflow); + let lhs = b.mul(signed, one_minus_overflow); + b.sub(lhs, sign_q) + } + DvrmConstraintKind::CarryIsBit(i) => { + // carry[i] * (1 - carry[i]) + let carry = self.capture_carry(i, b); + let one_minus_carry = b.sub(one, carry); + b.mul(carry, one_minus_carry) + } + DvrmConstraintKind::SignNSubRIsBit => { + // sign_n_sub_r * (1 - sign_n_sub_r) + let sign = b.main(0, cols::SIGN_N_SUB_R); + let one_minus_sign = b.sub(one, sign); + b.mul(sign, one_minus_sign) + } + DvrmConstraintKind::UnsignedSignN => { + // (1-signed) * sign_n + let signed = b.main(0, cols::SIGNED); + let sign_n = b.main(0, cols::SIGN_N); + let one_minus_signed = b.sub(one, signed); + b.mul(one_minus_signed, sign_n) + } + DvrmConstraintKind::UnsignedSignR => { + // (1-signed) * sign_r + let signed = b.main(0, cols::SIGNED); + let sign_r = b.main(0, cols::SIGN_R); + let one_minus_signed = b.sub(one, signed); + b.mul(one_minus_signed, sign_r) + } + DvrmConstraintKind::UnsignedSignD => { + // (1-signed) * sign_d + let signed = b.main(0, cols::SIGNED); + let sign_d = b.main(0, cols::SIGN_D); + let one_minus_signed = b.sub(one, signed); + b.mul(one_minus_signed, sign_d) + } + DvrmConstraintKind::DivByZeroQ(i) => { + // div_by_zero * (q[i] - 65535) + let dbz = b.main(0, cols::DIV_BY_ZERO); + let q_col = match i { + 0 => cols::Q_0, + 1 => cols::Q_1, + 2 => cols::Q_2, + 3 => cols::Q_3, + _ => unreachable!(), + }; + let q = b.main(0, q_col); + let fill = b.const_base(SIGN_FILL); + let diff = b.sub(q, fill); + b.mul(dbz, diff) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the DVRM table. /// /// Returns: (constraints, next_constraint_idx) diff --git a/prover/src/tables/ec_scalar.rs b/prover/src/tables/ec_scalar.rs index dd8d483a2..fffa85721 100644 --- a/prover/src/tables/ec_scalar.rs +++ b/prover/src/tables/ec_scalar.rs @@ -18,6 +18,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -312,6 +313,21 @@ impl TransitionConstraint for MulZeroConst } } +impl Capture for MulZeroConstraint { + fn capture(&self, b: &mut IrBuilder) { + let a = b.main(0, self.a); + let bb = b.main(0, self.b); + let root = if self.b_complement { + let one = b.one(); + let one_minus_b = b.sub(one, bb); + b.mul(a, one_minus_b) + } else { + b.mul(a, bb) + }; + b.emit(self.constraint_idx, root); + } +} + /// Creates all EC_SCALAR transition constraints (20 total). pub fn create_constraints( constraint_idx_start: usize, diff --git a/prover/src/tables/ecdas.rs b/prover/src/tables/ecdas.rs index 6d508d363..24236e78a 100644 --- a/prover/src/tables/ecdas.rs +++ b/prover/src/tables/ecdas.rs @@ -12,6 +12,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -358,6 +359,142 @@ impl ConvCarry { } } } + + /// Capture `s_i`, mirroring [`Self::s_i`]. + fn capture_s_i(&self, b: &mut IrBuilder) -> Expr { + let i = self.i; + let col = |c: usize, b: &mut IrBuilder| -> Expr { b.main(0, c) }; + // bytes (zero beyond the stored length) + let byte = |base: usize, len: usize, j: usize, b: &mut IrBuilder| -> Expr { + if j < len { + col(base + j, b) + } else { + b.const_base(0) + } + }; + let lam = |j: usize, b: &mut IrBuilder| byte(cols::LAMBDA, 32, j, b); + let xg = |j: usize, b: &mut IrBuilder| byte(cols::XG, 32, j, b); + let xa = |j: usize, b: &mut IrBuilder| byte(cols::XA, 32, j, b); + let ya = |j: usize, b: &mut IrBuilder| byte(cols::YA, 32, j, b); + let yg = |j: usize, b: &mut IrBuilder| byte(cols::YG, 32, j, b); + let xr = |j: usize, b: &mut IrBuilder| byte(cols::XR, 32, j, b); + let yr = |j: usize, b: &mut IrBuilder| byte(cols::YR, 32, j, b); + let p_byte_e = |m: usize, b: &mut IrBuilder| -> Expr { + if m < 32 { + b.const_base(P_BYTES[m] as u64) + } else { + b.const_base(0) + } + }; + let r_byte_e = |m: usize, b: &mut IrBuilder| -> Expr { + if m < 33 { + b.const_base(R_BYTES[m] as u64) + } else { + b.const_base(0) + } + }; + + // r·P − q·P convolution (shared structure across all three relations). + let rq = |qbase: usize, i: usize, b: &mut IrBuilder| -> Expr { + let mut s = b.const_base(0); + for j in 0..=i { + let r_j = r_byte_e(j, b); + let q_j = byte(qbase, 33, j, b); + let diff = b.sub(r_j, q_j); + let p_ij = p_byte_e(i - j, b); + let term = b.mul(diff, p_ij); + s = b.add(s, term); + } + s + }; + + match self.relation { + Relation::Lambda => { + let op = col(cols::OP, b); + let one = b.one(); + // op·(Σ λ_j(xG-xA)_{i-j} + (yA_i - yG_i)) + let ya_i = ya(i, b); + let yg_i = yg(i, b); + let mut op_branch = b.sub(ya_i, yg_i); + for j in 0..=i { + let lam_j = lam(j, b); + let xg_ij = xg(i - j, b); + let xa_ij = xa(i - j, b); + let diff = b.sub(xg_ij, xa_ij); + let term = b.mul(lam_j, diff); + op_branch = b.add(op_branch, term); + } + // (1-op)·Σ (2 λ_j yA_{i-j} - 3 xA_j xA_{i-j}) + let mut notop_branch = b.const_base(0); + for j in 0..=i { + let two = b.const_base(2); + let lam_j = lam(j, b); + let ya_ij = ya(i - j, b); + let two_lam = b.mul(two, lam_j); + let term_pos = b.mul(two_lam, ya_ij); + + let three = b.const_base(3); + let xa_j = xa(j, b); + let xa_ij = xa(i - j, b); + let three_xa = b.mul(three, xa_j); + let term_neg = b.mul(three_xa, xa_ij); + + let sum_pos = b.add(notop_branch, term_pos); + notop_branch = b.sub(sum_pos, term_neg); + } + let op_term = b.mul(op, op_branch); + let one_minus_op = b.sub(one, op); + let notop_term = b.mul(one_minus_op, notop_branch); + let rq_term = rq(cols::Q0, i, b); + let s = b.add(op_term, notop_term); + b.add(s, rq_term) + } + Relation::Xr => { + let op = col(cols::OP, b); + let one = b.one(); + // Σ λ_j λ_{i-j} − xA_i − xG_i − xR_i − (1-op)(xA_i − xG_i) + rq + let mut s = b.const_base(0); + for j in 0..=i { + let lam_j = lam(j, b); + let lam_ij = lam(i - j, b); + let term = b.mul(lam_j, lam_ij); + s = b.add(s, term); + } + let xa_i = xa(i, b); + let xg_i = xg(i, b); + let xr_i = xr(i, b); + let s = b.sub(s, xa_i); + let s = b.sub(s, xg_i); + let s = b.sub(s, xr_i); + let xa_i2 = xa(i, b); + let xg_i2 = xg(i, b); + let diff = b.sub(xa_i2, xg_i2); + let one_minus_op = b.sub(one, op); + let term2 = b.mul(one_minus_op, diff); + let s = b.sub(s, term2); + let rq_term = rq(cols::Q1, i, b); + b.add(s, rq_term) + } + Relation::Yr => { + // Σ λ_j(xA-xR)_{i-j} − yA_i − yR_i + rq + let mut s = b.const_base(0); + for j in 0..=i { + let lam_j = lam(j, b); + let xa_ij = xa(i - j, b); + let xr_ij = xr(i - j, b); + let diff = b.sub(xa_ij, xr_ij); + let term = b.mul(lam_j, diff); + s = b.add(s, term); + } + let ya_i = ya(i, b); + let yr_i = yr(i, b); + let s = b.sub(s, ya_i); + let s = b.sub(s, yr_i); + let rq_term = rq(cols::Q2, i, b); + b.add(s, rq_term) + } + } + } } impl TransitionConstraint for ConvCarry { @@ -393,6 +530,30 @@ impl TransitionConstraint for ConvCarry { } } +impl Capture for ConvCarry { + fn capture(&self, b: &mut IrBuilder) { + let c_base = match self.relation { + Relation::Lambda => cols::C0, + Relation::Xr => cols::C1, + Relation::Yr => cols::C2, + }; + let c_i = b.main(0, c_base + self.i); + let c_prev = if self.i == 0 { + b.const_base(0) + } else { + b.main(0, c_base + self.i - 1) + }; + let s_i = self.capture_s_i(b); + + // 256·c_i − c_prev − s_i + let two_five_six = b.const_base(256); + let scaled = b.mul(two_five_six, c_i); + let s = b.sub(scaled, c_prev); + let root = b.sub(s, s_i); + b.emit(self.constraint_idx, root); + } +} + /// `col = 0` (unconditional, degree 1). Used for the closing `c_63 = 0`. pub struct ColIsZero { pub col: usize, @@ -415,6 +576,13 @@ impl TransitionConstraint for ColIsZero { } } +impl Capture for ColIsZero { + fn capture(&self, b: &mut IrBuilder) { + let root = b.main(0, self.col); + b.emit(self.constraint_idx, root); + } +} + /// `a · b = 0` or `a · (1 - b) = 0` (degree 2). pub struct MulZero { pub a: usize, @@ -445,6 +613,21 @@ impl TransitionConstraint for MulZero { } } +impl Capture for MulZero { + fn capture(&self, builder: &mut IrBuilder) { + let a = builder.main(0, self.a); + let bval = builder.main(0, self.b); + let root = if self.b_complement { + let one = builder.one(); + let one_minus_b = builder.sub(one, bval); + builder.mul(a, one_minus_b) + } else { + builder.mul(a, bval) + }; + builder.emit(self.constraint_idx, root); + } +} + /// Creates all ECDAS transition constraints (200 total). pub fn create_constraints( constraint_idx_start: usize, diff --git a/prover/src/tables/ecsm.rs b/prover/src/tables/ecsm.rs index f8ec0859d..75f8c4266 100644 --- a/prover/src/tables/ecsm.rs +++ b/prover/src/tables/ecsm.rs @@ -20,6 +20,7 @@ use executor::vm::instruction::execution::ECSM_SYSCALL_NUMBER; use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -658,6 +659,76 @@ impl ConvCarry { } s } + + /// Capture `s_i`, mirroring [`Self::s_i`]. + fn capture_s_i(&self, b: &mut IrBuilder) -> Expr { + let i = self.i; + let byte = |base: usize, len: usize, j: usize, b: &mut IrBuilder| -> Expr { + if j < len { + b.main(0, base + j) + } else { + b.const_base(0) + } + }; + let p_byte_e = |m: usize, b: &mut IrBuilder| -> Expr { + if m < 32 { + b.const_base(P_BYTES[m] as u64) + } else { + b.const_base(0) + } + }; + let mut s = b.const_base(0); + match self.relation { + Relation::X2 => { + // Σ xG_j·xG_{i-j} − x2_i − Σ q0_j·P_{i-j} + for j in 0..=i { + let xg_j = byte(cols::XG, 32, j, b); + let xg_ij = byte(cols::XG, 32, i - j, b); + let term = b.mul(xg_j, xg_ij); + s = b.add(s, term); + + let q0_j = byte(cols::Q0, 32, j, b); + let p_ij = p_byte_e(i - j, b); + let term2 = b.mul(q0_j, p_ij); + s = b.sub(s, term2); + } + let x2_i = byte(cols::X2, 32, i, b); + s = b.sub(s, x2_i); + } + Relation::Yg => { + // Σ (yG_j·yG_{i-j} + P_j·P_{i-j} − x2_j·xG_{i-j} − q1_j·P_{i-j}) − b_i + for j in 0..=i { + let yg_j = byte(cols::YG, 32, j, b); + let yg_ij = byte(cols::YG, 32, i - j, b); + let term = b.mul(yg_j, yg_ij); + s = b.add(s, term); + + let p_j = p_byte_e(j, b); + let p_ij = p_byte_e(i - j, b); + let term2 = b.mul(p_j, p_ij); + s = b.add(s, term2); + + let x2_j = byte(cols::X2, 32, j, b); + let xg_ij = byte(cols::XG, 32, i - j, b); + let term3 = b.mul(x2_j, xg_ij); + s = b.sub(s, term3); + + let q1_j = byte(cols::Q1, 33, j, b); + let p_ij2 = p_byte_e(i - j, b); + let term4 = b.mul(q1_j, p_ij2); + s = b.sub(s, term4); + } + if i == 0 { + // Only the curve constant `b` is gated by `µ`. + let mu = b.main(0, cols::MU); + let b_const = b.const_base(B); + let mu_b = b.mul(mu, b_const); + s = b.sub(s, mu_b); + } + } + } + s + } } impl TransitionConstraint for ConvCarry { @@ -689,6 +760,29 @@ impl TransitionConstraint for ConvCarry { } } +impl Capture for ConvCarry { + fn capture(&self, b: &mut IrBuilder) { + let c_base = match self.relation { + Relation::X2 => cols::C0, + Relation::Yg => cols::C1, + }; + let c_i = b.main(0, c_base + self.i); + let c_prev = if self.i == 0 { + b.const_base(0) + } else { + b.main(0, c_base + self.i - 1) + }; + let s_i = self.capture_s_i(b); + + // 256·c_i − c_prev − s_i + let two_five_six = b.const_base(256); + let scaled = b.mul(two_five_six, c_i); + let s = b.sub(scaled, c_prev); + let root = b.sub(s, s_i); + b.emit(self.constraint_idx, root); + } +} + /// `col = 0` (unconditional, degree 1). Used for the closing `c_63 = 0`. pub struct ColIsZero { pub col: usize, @@ -711,6 +805,13 @@ impl TransitionConstraint for ColIsZero { } } +impl Capture for ColIsZero { + fn capture(&self, b: &mut IrBuilder) { + let root = b.main(0, self.col); + b.emit(self.constraint_idx, root); + } +} + /// The two 256-bit addition-overflow checks (`k < N` and `xR < p`), whose 8 word-carries /// `c` are virtual. Each `c_i = 2^-32·(addend0_i + addend1_i + c_{i-1} − sum_i)`. The addition /// must overflow `2^256` (carry-out `c_7 = 1`), which proves the strict inequality: @@ -781,6 +882,41 @@ where c } +/// Captures the 8 word-carries of the addition for `kind`, mirroring [`carry_chain`]. +fn capture_carry_chain(kind: OverflowKind, b: &mut IrBuilder) -> [Expr; 8] { + let inv = b.const_base(INV_SHIFT_32); + let hl = kind.addend_hl_base(); + let bl = kind.sum_bl_base(); + let mut c: [Expr; 8] = std::array::from_fn(|_| b.const_base(0)); + let mut prev = b.const_base(0); + for (i, slot) in c.iter_mut().enumerate() { + // addend1 word i (from halfwords): hl[2i] + 2^16·hl[2i+1] + let h_lo = b.main(0, hl + 2 * i); + let h_hi = b.main(0, hl + 2 * i + 1); + let shift_16 = b.const_base(1u64 << 16); + let h_hi_scaled = b.mul(h_hi, shift_16); + let addend1 = b.add(h_lo, h_hi_scaled); + + // sum word i (from bytes): Σ bl[4i+b]·2^{8b} + let mut sum = b.const_base(0); + for byte_idx in 0..4 { + let byte = b.main(0, bl + 4 * i + byte_idx); + let shift = b.const_base(1u64 << (8 * byte_idx)); + let term = b.mul(byte, shift); + sum = b.add(sum, term); + } + + let addend0 = b.const_base(kind.const_word(i)); + let s = b.add(addend0, addend1); + let s = b.add(s, prev); + let s = b.sub(s, sum); + let ci = b.mul(s, inv); + *slot = ci; + prev = ci; + } + c +} + /// `µ · c_i · (1 - c_i) = 0` for a virtual carry bit (degree 3, since `c_i` is linear). pub struct CarryBit { pub kind: OverflowKind, @@ -807,6 +943,18 @@ impl TransitionConstraint for CarryBit { } } +impl Capture for CarryBit { + fn capture(&self, b: &mut IrBuilder) { + let c = capture_carry_chain(self.kind, b); + let mu = b.main(0, cols::MU); + let one = b.one(); + let one_minus_ci = b.sub(one, c[self.i]); + let mu_ci = b.mul(mu, c[self.i]); + let root = b.mul(mu_ci, one_minus_ci); + b.emit(self.constraint_idx, root); + } +} + /// `µ · (1 - c_7) = 0`: the top carry must be 1 (the addition overflows). pub struct OverflowRequired { pub kind: OverflowKind, @@ -831,6 +979,17 @@ impl TransitionConstraint for OverflowRequ } } +impl Capture for OverflowRequired { + fn capture(&self, b: &mut IrBuilder) { + let c = capture_carry_chain(self.kind, b); + let mu = b.main(0, cols::MU); + let one = b.one(); + let one_minus_c7 = b.sub(one, c[7]); + let root = b.mul(mu, one_minus_c7); + b.emit(self.constraint_idx, root); + } +} + /// Creates all ECSM transition constraints (148 total). pub fn create_constraints( constraint_idx_start: usize, diff --git a/prover/src/tables/eq.rs b/prover/src/tables/eq.rs index 453caa928..d4f841257 100644 --- a/prover/src/tables/eq.rs +++ b/prover/src/tables/eq.rs @@ -23,6 +23,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -282,6 +283,24 @@ impl TransitionConstraint for EqXorConstra } } +impl Capture for EqXorConstraint { + fn capture(&self, b: &mut IrBuilder) { + let res = b.main(0, cols::RES); + let eq = b.main(0, cols::EQ); + let invert = b.main(0, cols::INVERT); + let two = b.const_base(2); + + // res - (eq + invert - 2*eq*invert) + let two_eq = b.mul(two, eq); + let two_eq_invert = b.mul(two_eq, invert); + let sum = b.add(eq, invert); + let inner = b.sub(sum, two_eq_invert); + let root = b.sub(res, inner); + + b.emit(self.constraint_idx, root); + } +} + /// Creates all transition constraints for the EQ table. /// /// Returns the boxed constraints and the next available constraint index: diff --git a/prover/src/tables/keccak.rs b/prover/src/tables/keccak.rs index 0f305255b..d5a278088 100644 --- a/prover/src/tables/keccak.rs +++ b/prover/src/tables/keccak.rs @@ -18,6 +18,7 @@ use executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -520,6 +521,69 @@ impl TransitionConstraint } } +impl Capture for KeccakAddressNoOverflowConstraint { + fn capture(&self, b: &mut IrBuilder) { + // mirrors compute(): addr_lo/addr_hi as little-endian byte combinations, + // ptr_lo/ptr_hi from the top-lane state_ptr halfwords, then the ADD-style + // carry chain, gated by mu. + let addr_0 = b.main(0, cols::addr(0)); + let addr_1 = b.main(0, cols::addr(1)); + let addr_2 = b.main(0, cols::addr(2)); + let addr_3 = b.main(0, cols::addr(3)); + let c256 = b.const_base(256); + let c65536 = b.const_base(65536); + let c16777216 = b.const_base(16777216); + + let t1 = b.mul(addr_1, c256); + let addr_lo = b.add(addr_0, t1); + let t2 = b.mul(addr_2, c65536); + let addr_lo = b.add(addr_lo, t2); + let t3 = b.mul(addr_3, c16777216); + let addr_lo = b.add(addr_lo, t3); + + let addr_4 = b.main(0, cols::addr(4)); + let addr_5 = b.main(0, cols::addr(5)); + let addr_6 = b.main(0, cols::addr(6)); + let addr_7 = b.main(0, cols::addr(7)); + + let t4 = b.mul(addr_5, c256); + let addr_hi = b.add(addr_4, t4); + let t5 = b.mul(addr_6, c65536); + let addr_hi = b.add(addr_hi, t5); + let t6 = b.mul(addr_7, c16777216); + let addr_hi = b.add(addr_hi, t6); + + let ptr_24_0 = b.main(0, cols::state_ptr(24, 0)); + let ptr_24_1 = b.main(0, cols::state_ptr(24, 1)); + let ptr_24_2 = b.main(0, cols::state_ptr(24, 2)); + let ptr_24_3 = b.main(0, cols::state_ptr(24, 3)); + + let t7 = b.mul(ptr_24_1, c65536); + let ptr_lo = b.add(ptr_24_0, t7); + let t8 = b.mul(ptr_24_3, c65536); + let ptr_hi = b.add(ptr_24_2, t8); + + let inv_2_32 = b.const_base(INV_SHIFT_32); + let c192 = b.const_base(192); + + // carry_0 = (addr_lo + 192 - ptr_lo) * inv_2_32 + let s = b.add(addr_lo, c192); + let s = b.sub(s, ptr_lo); + let carry_0 = b.mul(s, inv_2_32); + + // carry_1 = (addr_hi + carry_0 - ptr_hi) * inv_2_32 + let s2 = b.add(addr_hi, carry_0); + let s2 = b.sub(s2, ptr_hi); + let inv_2_32_2 = b.const_base(INV_SHIFT_32); + let carry_1 = b.mul(s2, inv_2_32_2); + + let mu = b.main(0, cols::MU); + let root = b.mul(mu, carry_1); + + b.emit(self.constraint_idx, root); + } +} + /// Create constraints for the KECCAK core chip. /// /// Per spec (keccak:c:state_ptr): ADD template for each lane: diff --git a/prover/src/tables/load.rs b/prover/src/tables/load.rs index 250d565b2..0ccca2338 100644 --- a/prover/src/tables/load.rs +++ b/prover/src/tables/load.rs @@ -25,6 +25,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -597,6 +598,75 @@ impl TransitionConstraint for LoadConstrai } } +impl Capture for LoadConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let ff = b.const_base(255); + + let mu = b.main(0, cols::MU); + let read2 = b.main(0, cols::READ2); + let read4 = b.main(0, cols::READ4); + let read8 = b.main(0, cols::READ8); + let signed = b.main(0, cols::SIGNED); + let sign_bit = b.main(0, cols::SIGN_BIT); + + let root = match self.kind { + LoadConstraintKind::ReadImpliesMu => { + // (read2 + read4 + read8) * (1 - mu) + let read_sum = b.add(read2, read4); + let read_sum = b.add(read_sum, read8); + let one_minus_mu = b.sub(one, mu); + b.mul(read_sum, one_minus_mu) + } + LoadConstraintKind::ExtensionHigh(i) => { + // (1 - read8) * (res[i] - signed * sign_bit * 255) + let res_i = b.main(0, cols::RES[i]); + let signed_sign_bit = b.mul(signed, sign_bit); + let expected = b.mul(signed_sign_bit, ff); + let diff = b.sub(res_i, expected); + let one_minus_read8 = b.sub(one, read8); + b.mul(one_minus_read8, diff) + } + LoadConstraintKind::ExtensionMid(i) => { + // (1 - read4 - read8) * (res[i] - signed * sign_bit * 255) + let res_i = b.main(0, cols::RES[i]); + let signed_sign_bit = b.mul(signed, sign_bit); + let expected = b.mul(signed_sign_bit, ff); + let diff = b.sub(res_i, expected); + let coeff = b.sub(one, read4); + let coeff = b.sub(coeff, read8); + b.mul(coeff, diff) + } + LoadConstraintKind::ExtensionLow => { + // (1 - read2 - read4 - read8) * (res[1] - signed * sign_bit * 255) + let res_1 = b.main(0, cols::RES[1]); + let signed_sign_bit = b.mul(signed, sign_bit); + let expected = b.mul(signed_sign_bit, ff); + let diff = b.sub(res_1, expected); + let coeff = b.sub(one, read2); + let coeff = b.sub(coeff, read4); + let coeff = b.sub(coeff, read8); + b.mul(coeff, diff) + } + LoadConstraintKind::FlagIsBit(col) => { + // flag * (1 - flag) + let flag = b.main(0, col); + let one_minus_flag = b.sub(one, flag); + b.mul(flag, one_minus_flag) + } + LoadConstraintKind::WidthSumIsBit => { + // sum * (1 - sum), sum = read2 + read4 + read8 + let sum = b.add(read2, read4); + let sum = b.add(sum, read8); + let one_minus_sum = b.sub(one, sum); + b.mul(sum, one_minus_sum) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the LOAD table. pub fn constraints() -> Vec>> { diff --git a/prover/src/tables/lt.rs b/prover/src/tables/lt.rs index 02ed029bd..61f64e6b8 100644 --- a/prover/src/tables/lt.rs +++ b/prover/src/tables/lt.rs @@ -28,6 +28,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::TransitionConstraint; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -467,6 +468,52 @@ impl LtConstraint { (&rhs_hi + &sub_hi + &carry_0 - &lhs_hi) * &inv_2_32 } + /// Capture virtual carry[0], mirroring [`Self::compute_carry_0`]. + fn capture_carry_0(&self, b: &mut IrBuilder) -> Expr { + let lhs_0 = b.main(0, cols::LHS_0); + let rhs_0 = b.main(0, cols::RHS_0); + let sub_0 = b.main(0, cols::LHS_SUB_RHS_0); + let sub_1 = b.main(0, cols::LHS_SUB_RHS_1); + + let shift_16 = b.const_base(SHIFT_16); + let sub_1_shifted = b.mul(sub_1, shift_16); + let sub_lo = b.add(sub_0, sub_1_shifted); + + let inv_2_32 = b.const_base(crate::constraints::templates::INV_SHIFT_32); + let s = b.add(rhs_0, sub_lo); + let s = b.sub(s, lhs_0); + b.mul(s, inv_2_32) + } + + /// Capture virtual carry[1], mirroring [`Self::compute_carry_1`]. + fn capture_carry_1(&self, b: &mut IrBuilder) -> Expr { + let lhs_1 = b.main(0, cols::LHS_1); + let lhs_2 = b.main(0, cols::LHS_2); + let rhs_1 = b.main(0, cols::RHS_1); + let rhs_2 = b.main(0, cols::RHS_2); + let sub_2 = b.main(0, cols::LHS_SUB_RHS_2); + let sub_3 = b.main(0, cols::LHS_SUB_RHS_3); + + let shift_16 = b.const_base(SHIFT_16); + + let lhs_2_shifted = b.mul(lhs_2, shift_16); + let lhs_hi = b.add(lhs_1, lhs_2_shifted); + + let rhs_2_shifted = b.mul(rhs_2, shift_16); + let rhs_hi = b.add(rhs_1, rhs_2_shifted); + + let sub_3_shifted = b.mul(sub_3, shift_16); + let sub_hi = b.add(sub_2, sub_3_shifted); + + let carry_0 = self.capture_carry_0(b); + + let inv_2_32 = b.const_base(crate::constraints::templates::INV_SHIFT_32); + let s = b.add(rhs_hi, sub_hi); + let s = b.add(s, carry_0); + let s = b.sub(s, lhs_hi); + b.mul(s, inv_2_32) + } + /// Compute the constraint value. fn compute(&self, step: &TableView) -> FieldElement where @@ -562,6 +609,78 @@ impl TransitionConstraint for LtConstraint } } +impl Capture for LtConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let root = match self.kind { + LtConstraintKind::Carry0IsBit => { + // carry[0] * (1 - carry[0]) + let c0 = self.capture_carry_0(b); + let one_minus_c0 = b.sub(one, c0); + b.mul(c0, one_minus_c0) + } + LtConstraintKind::Carry1IsBit => { + // carry[1] * (1 - carry[1]) + let c1 = self.capture_carry_1(b); + let one_minus_c1 = b.sub(one, c1); + b.mul(c1, one_minus_c1) + } + LtConstraintKind::LtFormula => { + // lt = signed*(A*(1-B) + A*C + (1-B)*C) + (1-signed)*unsigned_lt + let lt = b.main(0, cols::LT); + let signed = b.main(0, cols::SIGNED); + let a = b.main(0, cols::LHS_MSB); + let bb = b.main(0, cols::RHS_MSB); + let c = self.capture_carry_1(b); + let unsigned_lt = c; + + // signed_lt = A*(1-B) + A*C + (1-B)*C + let one_minus_b = b.sub(one, bb); + let a_term = b.mul(a, one_minus_b); + let ac_term = b.mul(a, c); + let bc_term = b.mul(one_minus_b, c); + let signed_lt = b.add(a_term, ac_term); + let signed_lt = b.add(signed_lt, bc_term); + + // expected_lt = signed*signed_lt + (1-signed)*unsigned_lt + let signed_part = b.mul(signed, signed_lt); + let one_minus_signed = b.sub(one, signed); + let unsigned_part = b.mul(one_minus_signed, unsigned_lt); + let expected_lt = b.add(signed_part, unsigned_part); + + b.sub(lt, expected_lt) + } + LtConstraintKind::OutXorInvert => { + // out - (lt + invert - 2*lt*invert) + let out = b.main(0, cols::OUT); + let lt = b.main(0, cols::LT); + let invert = b.main(0, cols::INVERT); + let two = b.const_base(2); + let lt_invert = b.mul(lt, invert); + let two_lt_invert = b.mul(two, lt_invert); + let s = b.add(lt, invert); + let s = b.sub(s, two_lt_invert); + b.sub(out, s) + } + LtConstraintKind::InvertIsBit => { + // invert * (1 - invert) + let invert = b.main(0, cols::INVERT); + let one_minus_invert = b.sub(one, invert); + b.mul(invert, one_minus_invert) + } + LtConstraintKind::SignedIsBit => { + // signed * (1 - signed) + let signed = b.main(0, cols::SIGNED); + let one_minus_signed = b.sub(one, signed); + b.mul(signed, one_minus_signed) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the LT table. /// /// Returns: (constraints, next_constraint_idx) diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index 2b240747c..838094c34 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -31,6 +31,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -864,6 +865,22 @@ where mu_read + mu_write } +/// Capture virtual w2 = write2 + write4 + write8, mirroring [`compute_w2`]. +fn capture_w2(b: &mut IrBuilder) -> Expr { + let write2 = b.main(0, cols::WRITE2); + let write4 = b.main(0, cols::WRITE4); + let write8 = b.main(0, cols::WRITE8); + let s = b.add(write2, write4); + b.add(s, write8) +} + +/// Capture virtual μ_sum = μ_read + μ_write, mirroring [`compute_mu_sum`]. +fn capture_mu_sum(b: &mut IrBuilder) -> Expr { + let mu_read = b.main(0, cols::MU_READ); + let mu_write = b.main(0, cols::MU_WRITE); + b.add(mu_read, mu_write) +} + // ========================================================================= // Constraints (11 total: 2 custom + 2 IS_BIT for multiplicities + 7 IS_BIT for carry) // ========================================================================= @@ -940,6 +957,33 @@ impl TransitionConstraint for MemwConstrai } } +impl Capture for MemwConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let root = match self.kind { + MemwConstraintKind::MuSumIsBit => { + let mu_sum = capture_mu_sum(b); + let one_minus_mu_sum = b.sub(one, mu_sum); + b.mul(mu_sum, one_minus_mu_sum) + } + MemwConstraintKind::W2ImpliesMuSum => { + let w2 = capture_w2(b); + let mu_sum = capture_mu_sum(b); + let one_minus_mu_sum = b.sub(one, mu_sum); + b.mul(w2, one_minus_mu_sum) + } + MemwConstraintKind::WidthSumIsBit => { + let w2 = capture_w2(b); + let one_minus_w2 = b.sub(one, w2); + b.mul(w2, one_minus_w2) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the MEMW table. /// /// 15 constraints total: diff --git a/prover/src/tables/memw_aligned.rs b/prover/src/tables/memw_aligned.rs index 8042d9052..c52edcb09 100644 --- a/prover/src/tables/memw_aligned.rs +++ b/prover/src/tables/memw_aligned.rs @@ -36,6 +36,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -723,6 +724,42 @@ impl TransitionConstraint for MemwAlignedC } } +impl Capture for MemwAlignedConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let mu_read = b.main(0, cols::MU_READ); + let mu_write = b.main(0, cols::MU_WRITE); + let mu_sum = b.add(mu_read, mu_write); + + let root = match self.kind { + MemwAlignedConstraintKind::MuSumIsBit => { + let one_minus_mu_sum = b.sub(one, mu_sum); + b.mul(mu_sum, one_minus_mu_sum) + } + MemwAlignedConstraintKind::W2ImpliesMuSum => { + let write2 = b.main(0, cols::WRITE2); + let write4 = b.main(0, cols::WRITE4); + let write8 = b.main(0, cols::WRITE8); + let w2 = b.add(write2, write4); + let w2 = b.add(w2, write8); + let one_minus_mu_sum = b.sub(one, mu_sum); + b.mul(w2, one_minus_mu_sum) + } + MemwAlignedConstraintKind::WidthSumIsBit => { + let write2 = b.main(0, cols::WRITE2); + let write4 = b.main(0, cols::WRITE4); + let write8 = b.main(0, cols::WRITE8); + let w2 = b.add(write2, write4); + let w2 = b.add(w2, write8); + let one_minus_w2 = b.sub(one, w2); + b.mul(w2, one_minus_w2) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the MEMW_A table (8 total). The last four are the /// spec's defense-in-depth width-flag assumptions. pub fn constraints() diff --git a/prover/src/tables/memw_register.rs b/prover/src/tables/memw_register.rs index 14a696cb9..171b1adfa 100644 --- a/prover/src/tables/memw_register.rs +++ b/prover/src/tables/memw_register.rs @@ -40,6 +40,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -403,6 +404,19 @@ impl TransitionConstraint for MemwRegister } } +impl Capture for MemwRegisterMuSumIsBit { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let mu_read = b.main(0, cols::MU_READ); + let mu_write = b.main(0, cols::MU_WRITE); + let mu_sum = b.add(mu_read, mu_write); + // mu_sum * (1 - mu_sum) + let one_minus_mu_sum = b.sub(one, mu_sum); + let root = b.mul(mu_sum, one_minus_mu_sum); + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the MEMW_R table (3 total). /// /// - IS_BIT(MU_READ) -- unconditional diff --git a/prover/src/tables/mul.rs b/prover/src/tables/mul.rs index 33679211c..5f2925754 100644 --- a/prover/src/tables/mul.rs +++ b/prover/src/tables/mul.rs @@ -32,6 +32,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::TransitionConstraint; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -842,6 +843,70 @@ impl MulConstraint { raw_product - sum } + + /// Capture raw_product constraint for index `i`, mirroring + /// [`Self::compute_raw_product_constraint`]. The `for j`/`for k` loops are + /// bounded by `i` (compile-time), so they unroll into builder calls. + fn capture_raw_product_constraint(&self, i: usize, b: &mut IrBuilder) -> Expr { + let lhs: [Expr; 4] = [ + b.main(0, cols::LHS_0), + b.main(0, cols::LHS_1), + b.main(0, cols::LHS_2), + b.main(0, cols::LHS_3), + ]; + let rhs: [Expr; 4] = [ + b.main(0, cols::RHS_0), + b.main(0, cols::RHS_1), + b.main(0, cols::RHS_2), + b.main(0, cols::RHS_3), + ]; + let lhs_is_neg = b.main(0, cols::LHS_IS_NEGATIVE); + let rhs_is_neg = b.main(0, cols::RHS_IS_NEGATIVE); + + let sign_fill = b.const_base(SIGN_FILL); + let zero = b.const_base(0); + let mut lhs_ext: [Expr; 8] = [zero; 8]; + let mut rhs_ext: [Expr; 8] = [zero; 8]; + lhs_ext[..4].copy_from_slice(&lhs); + rhs_ext[..4].copy_from_slice(&rhs); + for j in 4..8 { + lhs_ext[j] = b.mul(sign_fill, lhs_is_neg); + rhs_ext[j] = b.mul(sign_fill, rhs_is_neg); + } + + let shift_16 = b.const_base(SHIFT_16); + let mut sum = zero; + + for k in 0..=1u32 { + let idx = 2 * i + k as usize; + if idx < 8 { + let mut inner_sum = zero; + for j in 0..=idx { + if j < 8 && (idx - j) < 8 { + let term = b.mul(lhs_ext[j], rhs_ext[idx - j]); + inner_sum = b.add(inner_sum, term); + } + } + if k == 0 { + sum = b.add(sum, inner_sum); + } else { + let scaled = b.mul(inner_sum, shift_16); + sum = b.add(sum, scaled); + } + } + } + + let raw_col = match i { + 0 => cols::RAW_PRODUCT_0, + 1 => cols::RAW_PRODUCT_1, + 2 => cols::RAW_PRODUCT_2, + 3 => cols::RAW_PRODUCT_3, + _ => unreachable!(), + }; + let raw_product = b.main(0, raw_col); + + b.sub(raw_product, sum) + } } impl TransitionConstraint for MulConstraint { @@ -871,6 +936,38 @@ impl TransitionConstraint for MulConstrain } } +impl Capture for MulConstraint { + fn capture(&self, b: &mut IrBuilder) { + let root = match self.kind { + MulConstraintKind::LhsSign => { + // (1 - lhs_signed) * lhs_is_negative + let lhs_signed = b.main(0, cols::LHS_SIGNED); + let lhs_is_neg = b.main(0, cols::LHS_IS_NEGATIVE); + let one = b.one(); + let one_minus_signed = b.sub(one, lhs_signed); + b.mul(one_minus_signed, lhs_is_neg) + } + MulConstraintKind::RhsSign => { + // (1 - rhs_signed) * rhs_is_negative + let rhs_signed = b.main(0, cols::RHS_SIGNED); + let rhs_is_neg = b.main(0, cols::RHS_IS_NEGATIVE); + let one = b.one(); + let one_minus_signed = b.sub(one, rhs_signed); + b.mul(one_minus_signed, rhs_is_neg) + } + MulConstraintKind::SignedIsBit(col) => { + // x * (1 - x) + let x = b.main(0, col); + let one = b.one(); + let one_minus_x = b.sub(one, x); + b.mul(x, one_minus_x) + } + MulConstraintKind::RawProduct(i) => self.capture_raw_product_constraint(i, b), + }; + b.emit(self.constraint_idx, root); + } +} + /// Creates all constraints for the MUL table. /// /// Returns: (constraints, next_constraint_idx) diff --git a/prover/src/tables/shift.rs b/prover/src/tables/shift.rs index 3115784f6..14a3f8355 100644 --- a/prover/src/tables/shift.rs +++ b/prover/src/tables/shift.rs @@ -19,6 +19,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::constraints::transition::TransitionConstraint; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -839,6 +840,92 @@ impl ShiftConstraint { left_part + right_part } + /// Capture the `shifted` virtual column at index `half_idx` (0..4), + /// mirroring [`Self::compute_shifted_half`]. + fn capture_shifted_half(half_idx: usize, b: &mut IrBuilder) -> Expr { + let dir = b.main(0, cols::DIRECTION); + let mu = b.main(0, cols::MU); + let left = b.sub(mu, dir); // mu - direction + let right = dir; + + // extension = 65535 * is_negative + let is_neg = b.main(0, cols::IS_NEGATIVE); + let shift_fill = b.const_base(65535u64); + let extension = b.mul(is_neg, shift_fill); + + let get_x = |i: usize, b: &mut IrBuilder| b.main(0, cols::X[i]); + let get_y = |i: usize, b: &mut IrBuilder| b.main(0, cols::Y[i]); + let get_ls = |i: usize, b: &mut IrBuilder| -> Expr { + if i < 3 { + b.main(0, cols::LIMB_SHIFT_RAW[i]) + } else { + // limb_shift[3] is virtual: 1 - ls_raw[0] - ls_raw[1] - ls_raw[2] + let one = b.one(); + let ls0 = b.main(0, cols::LIMB_SHIFT_RAW[0]); + let ls1 = b.main(0, cols::LIMB_SHIFT_RAW[1]); + let ls2 = b.main(0, cols::LIMB_SHIFT_RAW[2]); + let r = b.sub(one, ls0); + let r = b.sub(r, ls1); + b.sub(r, ls2) + } + }; + + // intra_limb_left[i]: X[0] for i=0, X[i]+Y[i-1] for i>0 + let intra_left = |i: usize, b: &mut IrBuilder| -> Expr { + if i == 0 { + get_x(0, b) + } else { + let xi = get_x(i, b); + let yi1 = get_y(i - 1, b); + b.add(xi, yi1) + } + }; + + // intra_limb_right[i]: Y[i]+X[i+1] + let intra_right = |i: usize, b: &mut IrBuilder| -> Expr { + let yi = get_y(i, b); + let xi1 = get_x(i + 1, b); + b.add(yi, xi1) + }; + + let i = half_idx; + let zero = b.const_base(0); + + // left_part = left * Σ_j=0^i limb_shift[j] * intra_limb_left[i-j] + let mut left_part = zero; + for j in 0..=i { + let ls_j = get_ls(j, b); + let il = intra_left(i - j, b); + let term = b.mul(ls_j, il); + left_part = b.add(left_part, term); + } + let left_part = b.mul(left, left_part); + + // right_shift_part = right * Σ_j=0^(3-i) limb_shift[j] * intra_limb_right[i+j] + let mut right_shift_part = zero; + for j in 0..=(3 - i) { + let ls_j = get_ls(j, b); + let ir = intra_right(i + j, b); + let term = b.mul(ls_j, ir); + right_shift_part = b.add(right_shift_part, term); + } + + // right_ext_part = right * extension * Σ_j=(4-i)^3 limb_shift[j] + let mut ext_sum = zero; + if i < 4 { + for j in (4 - i)..4 { + let ls_j = get_ls(j, b); + ext_sum = b.add(ext_sum, ls_j); + } + } + let right_ext_part = b.mul(extension, ext_sum); + + let right_inner = b.add(right_shift_part, right_ext_part); + let right_part = b.mul(right, right_inner); + + b.add(left_part, right_part) + } + fn compute(&self, step: &TableView) -> FieldElement where F: IsSubFieldOf, @@ -937,6 +1024,85 @@ impl TransitionConstraint for ShiftConstra } } +impl Capture for ShiftConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + let shift_16 = b.const_base(SHIFT_16); + + let root = match self.kind { + ShiftConstraintKind::DirectionImpliesMu => { + // direction * (1 - mu) + let dir = b.main(0, cols::DIRECTION); + let mu = b.main(0, cols::MU); + let one_minus_mu = b.sub(one, mu); + b.mul(dir, one_minus_mu) + } + ShiftConstraintKind::ZbsOverrideX(i) => { + // zbs * (X[i] - in[i] * left), left = mu - direction + let zbs = b.main(0, cols::ZBS); + let x_i = b.main(0, cols::X[i]); + let in_i = b.main(0, cols::IN[i]); + let mu = b.main(0, cols::MU); + let dir = b.main(0, cols::DIRECTION); + let left = b.sub(mu, dir); + let in_left = b.mul(in_i, left); + let diff = b.sub(x_i, in_left); + b.mul(zbs, diff) + } + ShiftConstraintKind::ZbsOverrideX4 => { + // zbs * X[4] + let zbs = b.main(0, cols::ZBS); + let x4 = b.main(0, cols::X_4); + b.mul(zbs, x4) + } + ShiftConstraintKind::ZbsOverrideY(i) => { + // zbs * (Y[i] - in[i] * right), right = direction + let zbs = b.main(0, cols::ZBS); + let y_i = b.main(0, cols::Y[i]); + let in_i = b.main(0, cols::IN[i]); + let dir = b.main(0, cols::DIRECTION); + let in_dir = b.mul(in_i, dir); + let diff = b.sub(y_i, in_dir); + b.mul(zbs, diff) + } + ShiftConstraintKind::LimbShiftIsBit(i) => { + // limb_shift[i] * (1 - limb_shift[i]) + let ls = if i < 3 { + b.main(0, cols::LIMB_SHIFT_RAW[i]) + } else { + let ls0 = b.main(0, cols::LIMB_SHIFT_RAW[0]); + let ls1 = b.main(0, cols::LIMB_SHIFT_RAW[1]); + let ls2 = b.main(0, cols::LIMB_SHIFT_RAW[2]); + let r = b.sub(one, ls0); + let r = b.sub(r, ls1); + b.sub(r, ls2) + }; + let one_minus_ls = b.sub(one, ls); + b.mul(ls, one_minus_ls) + } + ShiftConstraintKind::OutputMatchesShifted(i) => { + // out[i] - (shifted::DWordWL)[i] + // (shifted::DWordWL)[i] = shifted[2*i] + shifted[2*i+1] * 2^16 + let out_col = if i == 0 { cols::OUT_0 } else { cols::OUT_1 }; + let out = b.main(0, out_col); + let half_lo = Self::capture_shifted_half(2 * i, b); + let half_hi = Self::capture_shifted_half(2 * i + 1, b); + let half_hi_shifted = b.mul(half_hi, shift_16); + let r = b.sub(out, half_lo); + b.sub(r, half_hi_shifted) + } + ShiftConstraintKind::FlagIsBit(col) => { + // flag * (1 - flag) + let flag = b.main(0, col); + let one_minus_flag = b.sub(one, flag); + b.mul(flag, one_minus_flag) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Number of polynomial constraints in the SHIFT table. // 1 (DirectionImpliesMu) + 4 (ZbsOverrideX) + 1 (ZbsOverrideX4) + 4 (ZbsOverrideY) // + 4 (LimbShiftIsBit) + 2 (OutputMatchesShifted) + 3 (FlagIsBit) = 19 diff --git a/prover/src/tables/store.rs b/prover/src/tables/store.rs index 1cdf0334e..4009fa504 100644 --- a/prover/src/tables/store.rs +++ b/prover/src/tables/store.rs @@ -21,6 +21,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; @@ -308,6 +309,31 @@ impl TransitionConstraint for StoreConstra } } +impl Capture for StoreConstraint { + fn capture(&self, b: &mut IrBuilder) { + let w2 = b.main(0, cols::WRITE2); + let w4 = b.main(0, cols::WRITE4); + let w8 = b.main(0, cols::WRITE8); + let sum = b.add(w2, w4); + let sum = b.add(sum, w8); + let one = b.one(); + + let root = match self.kind { + StoreConstraintKind::WidthSumIsBit => { + let one_minus_sum = b.sub(one, sum); + b.mul(sum, one_minus_sum) + } + StoreConstraintKind::WidthImpliesMu => { + let mu = b.main(0, cols::MU); + let one_minus_mu = b.sub(one, mu); + b.mul(sum, one_minus_mu) + } + }; + + b.emit(self.constraint_idx, root); + } +} + /// Creates all transition constraints for the STORE table: `IS_BIT` on each /// width flag, the width-sum-is-bit constraint, and width ⇒ μ. pub fn store_constraints( diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index fd9d9d40c..32494e6b7 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -139,7 +139,7 @@ where /// With zero bus interactions, `AirWithBuses::new` appends no LogUp constraints /// and allocates no aux columns, so `validate_trace` evaluates exactly the chip's /// transition constraints over a main-only trace. -pub fn busless_air + 'static>( +pub fn busless_air + stark::constraint_ir::Capture + 'static>( num_columns: usize, constraints: Vec, ) -> VmAir { diff --git a/prover/src/tests/constraint_ir_tests.rs b/prover/src/tests/constraint_ir_tests.rs index 86bf51d81..226f67412 100644 --- a/prover/src/tests/constraint_ir_tests.rs +++ b/prover/src/tests/constraint_ir_tests.rs @@ -51,7 +51,7 @@ where { let mut b = IrBuilder::new(); c.capture(&mut b); - let prog = b.finish(); + let prog = b.finish(0); eprintln!("[{label}] captured {} IR nodes", prog.len()); let mut rng = SplitMix64::new(0xDEAD_BEEF_CAFE_F00D ^ (label.len() as u64)); @@ -67,7 +67,7 @@ where c.evaluate::(&real_step); // IR interpreter over the same row. - let got = eval_program_base(&prog, &row); + let got = eval_program_base(&prog, c.constraint_idx(), &row); assert_eq!( real, got, @@ -111,3 +111,256 @@ fn test_ir_matches_product_zero() { let c = ProductZeroConstraint::new(12, 17, 0); assert_ir_matches_evaluate(&c, "product_zero"); } + +// ============================================================================= +// Phase 1 GATE: full-table, full-program differential test (CPU, LogUp-heavy). +// +// `create_cpu_air` assembles every algebraic CPU constraint AND, via +// `AirWithBuses::new`, the 2 LogUp constraints for its bus interactions +// (DECODE/ALU/MEMORY/CPU32/MEMW/BRANCH/ECALL). Capturing its full program and +// interpreting it over a real LDE must reproduce `air.compute_transition_prover` +// (prover) and `air.compute_transition` (verifier, at the OOD point) bit-for-bit. +// ============================================================================= +mod full_table_gate { + use crate::tables::cpu::{CpuOperation, generate_cpu_trace}; + use crate::tables::eq::{EqOperation, generate_eq_trace}; + use crate::tables::types::DecodeEntry; + use crate::test_utils::{VmAir, create_cpu_air, create_eq_air}; + + use crypto::fiat_shamir::default_transcript::DefaultTranscript; + use executor::vm::instruction::decoding::{ArithOp, Instruction}; + use executor::vm::logs::Log; + use math::field::element::FieldElement; + use stark::constraint_ir::{eval_program, eval_program_verifier}; + use stark::frame::Frame; + use stark::proof::options::ProofOptions; + use stark::table::TableView; + use stark::trace::{LDETraceTable, TraceTable}; + use stark::traits::{AIR, TransitionEvaluationContext}; + + use super::{GoldilocksExtension, GoldilocksField}; + + const PC: u64 = 0x1000; + + /// Build a `CpuOperation` from an instruction + register values (mirrors + /// `prover/src/tests/cpu_tests.rs::op_of`, duplicated here to keep this + /// gate test self-contained). + fn op_of(instr: Instruction, src1: u64, src2: u64, dst: u64, next_pc: u64) -> CpuOperation { + let decode = DecodeEntry::from_instruction(PC, instr, 4); + let log = Log { + current_pc: PC, + next_pc, + src1_val: src1, + src2_val: src2, + dst_val: dst, + }; + CpuOperation::from_log(&log, 4, decode) + } + + /// A handful of real CPU operations exercising different bus interactions + /// (ALU add, ALU sub) so the captured LogUp program sees non-trivial + /// fingerprints/multiplicities on every row, not just padding zeros. + fn sample_operations() -> Vec { + vec![ + op_of( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }, + 10, + 20, + 30, + PC + 4, + ), + op_of( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Sub, + }, + 50, + 20, + 30, + PC + 4, + ), + ] + } + + /// A handful of real EQ operations (BEQ-style and BNE-style) so the + /// captured LogUp program (1 batched pair + 1 absorbed interaction, the + /// branch not exercised by the CPU table below) sees non-trivial + /// fingerprints on every row. + fn sample_eq_operations() -> Vec { + vec![ + EqOperation::new(42, 42, false), + EqOperation::new(7, 9, true), + ] + } + + #[test] + fn test_cpu_table_full_program_matches_boxed_path_prover_and_verifier() { + let air = create_cpu_air(&ProofOptions::default_test_options()); + let trace = generate_cpu_trace(&sample_operations()); + assert_full_table_ir_matches_boxed_path(&air, trace, "cpu_full_table"); + } + + #[test] + fn test_eq_table_full_program_matches_boxed_path_prover_and_verifier() { + let air = create_eq_air(&ProofOptions::default_test_options()); + let trace = generate_eq_trace(&sample_eq_operations()); + assert_full_table_ir_matches_boxed_path(&air, trace, "eq_full_table"); + } + + /// Capture `air`'s full program, then for every prover-side LDE row and one + /// verifier-side OOD point, assert the IR interpreter reproduces the boxed + /// `compute_transition_prover`/`compute_transition` path bit-for-bit. + fn assert_full_table_ir_matches_boxed_path( + air: &VmAir, + mut trace: TraceTable, + label: &str, + ) { + // Build the aux (LogUp) trace + rap challenges, exactly as the prover + // pipeline would (minus the surrounding LDE/FRI machinery, which the + // constraint evaluator doesn't touch). + let mut transcript = DefaultTranscript::::new(&[]); + let rap_challenges = air.build_rap_challenges(&mut transcript); + air.build_auxiliary_trace(&mut trace, &rap_challenges); + + let num_rows = trace.num_rows(); + assert!(num_rows >= 2, "need >=2 rows for the LogUp next-row read"); + + let main_columns: Vec>> = (0..trace.num_main_columns) + .map(|col| { + (0..num_rows) + .map(|row| *trace.main_table.get(row, col)) + .collect() + }) + .collect(); + let aux_columns: Vec>> = (0..trace.num_aux_columns) + .map(|col| { + (0..num_rows) + .map(|row| *trace.aux_table.get(row, col)) + .collect() + }) + .collect(); + let lde_trace = + LDETraceTable::from_columns(main_columns.clone(), aux_columns.clone(), 1, 1); + + let prog = air.constraint_program(); + eprintln!( + "[{label}] captured {} IR nodes, {} constraints (num_base={})", + prog.len(), + prog.roots.len(), + prog.num_base + ); + assert_eq!( + prog.roots.len(), + air.num_transition_constraints(), + "every constraint_idx must have been emitted" + ); + assert_eq!(prog.num_base, air.num_base_transition_constraints()); + + let num_base = air.num_base_transition_constraints(); + let num_transition = air.num_transition_constraints(); + let no_periodic: Vec> = Vec::new(); + let logup_alpha_powers = { + // Mirrors `ConstraintEvaluator::evaluate_transitions`'s alpha-power + // precompute (`compute_alpha_powers`, crate-private in `stark`): + // [1, alpha, alpha^2, ...], rap_challenges[1] is alpha. + use stark::lookup::LOGUP_CHALLENGE_ALPHA; + if rap_challenges.len() > LOGUP_CHALLENGE_ALPHA { + let alpha = &rap_challenges[LOGUP_CHALLENGE_ALPHA]; + let count = air.max_bus_elements(); + let mut powers = Vec::with_capacity(count); + let mut cur = FieldElement::::one(); + for _ in 0..count { + powers.push(cur); + cur *= alpha; + } + powers + } else { + Vec::new() + } + }; + let logup_table_offset = FieldElement::::zero(); + let packing_shifts = stark::lookup::PackingShifts::::new(); + + // --- Prover-side: every row, boxed path vs IR interpreter --- + let offsets = &air.context().transition_offsets; + for step in 0..lde_trace.num_steps() { + let frame: Frame = + Frame::read_step_from_lde(&lde_trace, step, offsets); + let ctx = TransitionEvaluationContext::new_prover( + &frame, + &no_periodic, + &rap_challenges, + &logup_alpha_powers, + &logup_table_offset, + &packing_shifts, + ); + + let mut boxed_base = vec![FieldElement::::zero(); num_base]; + let mut boxed_ext = vec![FieldElement::::zero(); num_transition]; + air.compute_transition_prover(&ctx, &mut boxed_base, &mut boxed_ext); + + let mut ir_base = vec![FieldElement::::zero(); num_base]; + let mut ir_ext = vec![FieldElement::::zero(); num_transition]; + eval_program(&prog, &ctx, &mut ir_base, &mut ir_ext); + + assert_eq!(boxed_base, ir_base, "base evals mismatch at step {step}"); + assert_eq!( + boxed_ext[num_base..], + ir_ext[num_base..], + "ext (LogUp) evals mismatch at step {step}" + ); + } + + // --- Verifier-side: at one "OOD" point, boxed path vs IR interpreter --- + // The verifier frame holds only extension-field elements; embed the + // same real row data (rows 0 and 1, matching transition_offsets=[0,1]) + // into GoldilocksExtension to build it, exactly as `evaluate_zerofier`'s + // sibling machinery would after a real FRI opening. + let embed_row = |row: usize| -> ( + Vec>, + Vec>, + ) { + let main: Vec<_> = main_columns + .iter() + .map(|col| col[row].to_extension()) + .collect(); + let aux: Vec<_> = aux_columns.iter().map(|col| col[row]).collect(); + (main, aux) + }; + let (main0, aux0) = embed_row(0); + let (main1, aux1) = embed_row(1); + let verifier_frame: Frame = Frame::new(vec![ + TableView::new(vec![main0], vec![aux0]), + TableView::new(vec![main1], vec![aux1]), + ]); + let no_periodic_ext: Vec> = no_periodic + .iter() + .map(|x: &FieldElement| (*x).to_extension()) + .collect(); + let verifier_packing_shifts = stark::lookup::PackingShifts::::new(); + let verifier_ctx = TransitionEvaluationContext::new_verifier( + &verifier_frame, + &no_periodic_ext, + &rap_challenges, + &logup_alpha_powers, + &logup_table_offset, + &verifier_packing_shifts, + ); + + let boxed_verifier = air.compute_transition(&verifier_ctx); + let mut ir_verifier = vec![FieldElement::::zero(); num_transition]; + eval_program_verifier(&prog, &verifier_ctx, &mut ir_verifier); + + assert_eq!( + boxed_verifier, ir_verifier, + "verifier evals mismatch at OOD point" + ); + } +} From 1c9234252d957feed3214c36a005ea2aea1eec78 Mon Sep 17 00:00:00 2001 From: MauroFab Date: Tue, 30 Jun 2026 16:55:17 -0300 Subject: [PATCH 5/6] =?UTF-8?q?spike(stark):=20Phase=202=20=E2=80=94=20wir?= =?UTF-8?q?e=20constraint-ir=20interpreter=20into=20prover/verifier=20behi?= =?UTF-8?q?nd=20a=20feature?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the `constraint-ir` Cargo feature to crypto/stark. Behind it, the prover (ConstraintEvaluator::evaluate_transitions) and verifier (IsStarkVerifier::step_2_verify_claimed_composition_polynomial) interpret the captured ConstraintProgram in place of the boxed Vec> dispatch loop. The boxed path stays the default and the differential oracle (falls back automatically if the type tower isn't lambda_vm's Goldilocks base + degree-3 extension). - constraint_ir/bridge.rs: the generic Field/FieldExtension -> concrete Goldilocks TypeId seam (mirrors crate::gpu_lde's transmute pattern, using plain `transmute` on the reference/slice types rather than transmute_copy, since references are always pointer-sized regardless of the generic pointee). - constraints/evaluator.rs: ConstraintEvaluator caches the AIR::constraint_program() once in `new()` (feature-gated field); evaluate_transitions swaps the per-row compute_transition_prover call for try_eval_program_prover, falling back to the boxed call if it returns false. Requires 'static on the impl block (TypeId needs it). - verifier.rs: step_2_verify_claimed_composition_polynomial's transition evaluation is split into a tiny step_2_compute_transitions helper with two #[cfg] sibling bodies (boxed vs IR-interpreted-with-fallback); the surrounding boundary/zerofier/composition accounting is untouched. IsStarkVerifier and its sole impl gain a 'static bound (needed transitively for the TypeId check; safe since there is one implementor). Gate: cargo test -p lambda-vm-prover --release --features stark/constraint-ir passes all 430 tests (incl. every test_prove_elfs_* real end-to-end prove->verify), identical pass count to the default boxed path. fmt/clippy clean on both feature states; full workspace build unaffected. --- crypto/stark/Cargo.toml | 5 ++ crypto/stark/src/constraint_ir/bridge.rs | 99 +++++++++++++++++++++++ crypto/stark/src/constraint_ir/mod.rs | 24 ++++-- crypto/stark/src/constraints/evaluator.rs | 28 ++++++- crypto/stark/src/verifier.rs | 39 +++++++-- 5 files changed, 180 insertions(+), 15 deletions(-) create mode 100644 crypto/stark/src/constraint_ir/bridge.rs diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index d0f6a51ef..01c5fe5d6 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -54,6 +54,11 @@ cuda = ["dep:math-cuda"] test-cuda-faults = ["cuda", "math-cuda/test-faults"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] disk-spill = ["dep:memmap2", "dep:tempfile", "dep:libc", "crypto/disk-spill"] +# Swaps the captured constraint_ir interpreter into the prover/verifier +# transition-constraint evaluation, in place of the boxed +# Vec> dispatch loop. The boxed path +# stays the default + differential oracle (see crypto/stark/src/constraint_ir/). +constraint-ir = [] [package.metadata.wasm-pack.profile.dev] diff --git a/crypto/stark/src/constraint_ir/bridge.rs b/crypto/stark/src/constraint_ir/bridge.rs new file mode 100644 index 000000000..6c202fd99 --- /dev/null +++ b/crypto/stark/src/constraint_ir/bridge.rs @@ -0,0 +1,99 @@ +//! Generic-`Field`/`FieldExtension` → concrete-Goldilocks TypeId seam. +//! +//! `eval_program`/`eval_program_verifier` are concretely typed to +//! `GoldilocksField`/`Degree3GoldilocksExtensionField` (the IR is single-field, +//! see `crate::constraint_ir`), but the prover/verifier's evaluation loops +//! (`crate::constraints::evaluator::ConstraintEvaluator`, +//! `crate::verifier::verify`) are generic over `Field: IsSubFieldOf`. +//! `try_eval_program_prover`/`try_eval_program_verifier` bridge the two: a +//! `TypeId` check establishes `Field == GoldilocksField` and +//! `FieldExtension == Degree3GoldilocksExtensionField` exactly, after which a +//! `&TransitionEvaluationContext` is reinterpreted as +//! `&TransitionEvaluationContext` +//! (same layout — only the type parameters differ, and the check pins them to +//! be the same concrete type), mirroring the seam already used by +//! `crate::gpu_lde` (`TypeId::of::()` guards + `transmute_copy`). +//! +//! Returns `false` (no-op on the caller's buffers) when the TypeId check +//! fails, so callers fall back to the boxed path unconditionally outside the +//! lambda_vm single-field (Goldilocks base + degree-3 extension) setup. + +use std::any::TypeId; +use std::mem::transmute; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField as GoldilocksExtension; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; + +use super::interp::{eval_program, eval_program_verifier}; +use super::ir::ConstraintProgram; +use crate::traits::TransitionEvaluationContext; + +/// `true` iff `Field == GoldilocksField` and `FieldExtension == Degree3GoldilocksExtensionField`. +#[inline] +fn is_goldilocks_tower() -> bool { + TypeId::of::() == TypeId::of::() + && TypeId::of::() == TypeId::of::() +} + +/// Prover-side bridge: interpret `prog` via [`eval_program`] in place of the +/// boxed `air.compute_transition_prover(...)` call, writing the same +/// `base_evals`/`ext_evals` contract. Returns `true` if it ran (the type +/// tower matched Goldilocks); `false` otherwise (caller should fall back). +pub fn try_eval_program_prover( + prog: &ConstraintProgram, + ctx: &TransitionEvaluationContext, + base_evals: &mut [FieldElement], + ext_evals: &mut [FieldElement], +) -> bool +where + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: IsField + Send + Sync + 'static, +{ + if !is_goldilocks_tower::() { + return false; + } + // SAFETY: the TypeId check above establishes `Field == GoldilocksField` + // and `FieldExtension == Degree3GoldilocksExtensionField` exactly, so + // `TransitionEvaluationContext` and + // `[FieldElement]`/`[FieldElement]` have the same + // layout as their Goldilocks-concrete counterparts (same generic struct, + // same — now proven identical — type arguments). Mirrors the + // `transmute_copy` seam in `crate::gpu_lde`. + let ctx_gl: &TransitionEvaluationContext = + unsafe { transmute(ctx) }; + let base_gl: &mut [FieldElement] = unsafe { transmute(base_evals) }; + let ext_gl: &mut [FieldElement] = unsafe { transmute(ext_evals) }; + eval_program(prog, ctx_gl, base_gl, ext_gl); + true +} + +/// Verifier-side bridge: interpret `prog` via [`eval_program_verifier`] in +/// place of the boxed `air.compute_transition(...)` call, writing the same +/// `ext_evals` contract. Returns `true` if it ran; `false` otherwise. +/// +/// At the OOD point the verifier's `TransitionEvaluationContext` is always +/// `` (`Field` and `FieldExtension` are the same type — see +/// `TransitionEvaluationContext::Verifier`, which has no base-field data), so +/// unlike [`try_eval_program_prover`] this only needs `Field: IsSubFieldOf` +/// (reflexive for any field) and not `IsFFTField`. +pub fn try_eval_program_verifier( + prog: &ConstraintProgram, + ctx: &TransitionEvaluationContext, + ext_evals: &mut [FieldElement], +) -> bool +where + Field: IsSubFieldOf + Send + Sync + 'static, + FieldExtension: IsField + Send + Sync + 'static, +{ + if !is_goldilocks_tower::() { + return false; + } + // SAFETY: see `try_eval_program_prover`. + let ctx_gl: &TransitionEvaluationContext = + unsafe { transmute(ctx) }; + let ext_gl: &mut [FieldElement] = unsafe { transmute(ext_evals) }; + eval_program_verifier(prog, ctx_gl, ext_gl); + true +} diff --git a/crypto/stark/src/constraint_ir/mod.rs b/crypto/stark/src/constraint_ir/mod.rs index f904d51f1..c60af8188 100644 --- a/crypto/stark/src/constraint_ir/mod.rs +++ b/crypto/stark/src/constraint_ir/mod.rs @@ -1,24 +1,32 @@ -//! Explicit-builder constraint capture spike (Plan B). +//! Explicit-builder constraint capture (Plan B). //! -//! Proof-of-concept that lambda_vm's algebraic transition constraints can be -//! captured into a flat, single-field Goldilocks IR via an explicit -//! [`IrBuilder`] (rather than the recording "symbolic field" of Plan A), and -//! that interpreting that IR on the CPU reproduces the constraint's real -//! `evaluate` bit-for-bit. +//! Every transition constraint is captured once, at AIR-construction time, +//! into a flat single-field Goldilocks IR ([`ConstraintProgram`]) via an +//! explicit [`IrBuilder`] (rather than the recording "symbolic field" of Plan +//! A). Interpreting that IR on the CPU reproduces the constraint's real +//! `evaluate`/`compute` body bit-for-bit — including the LogUp framework +//! constraints (`crypto/stark/src/lookup.rs`). //! //! Both plans produce the SAME IR and use the SAME interpreter; they differ //! only in the capture front-end. Here each constraint implements [`Capture`] -//! and translates its `evaluate` body into builder calls. This is CPU-only and -//! does not touch the prover hot loop, the LogUp framework, or GPU code. +//! and translates its `evaluate` body into builder calls. +//! +//! Behind the `constraint-ir` Cargo feature, [`bridge`] swaps the interpreter +//! into the prover (`constraints/evaluator.rs`) and verifier (`verifier.rs`) +//! hot paths, in place of the boxed `Vec>` +//! dispatch loop. The boxed path stays the default and the differential oracle. //! //! - [`ir`]: the IR data structures ([`ConstraintProgram`], [`Op`], [`Dim`]). //! - [`builder`]: the [`IrBuilder`] and [`Expr`] capture API. //! - [`interp`]: a CPU forward-pass interpreter over the IR. +//! - [`bridge`]: the generic-`Field`/`FieldExtension` → concrete-Goldilocks +//! TypeId seam used to call the interpreter from the generic prover/verifier. //! //! [`ConstraintProgram`]: ir::ConstraintProgram //! [`Op`]: ir::Op //! [`Dim`]: ir::Dim +pub mod bridge; pub mod builder; pub mod interp; pub mod ir; diff --git a/crypto/stark/src/constraints/evaluator.rs b/crypto/stark/src/constraints/evaluator.rs index 6e94473b7..264c8abd8 100644 --- a/crypto/stark/src/constraints/evaluator.rs +++ b/crypto/stark/src/constraints/evaluator.rs @@ -21,12 +21,17 @@ pub struct ConstraintEvaluator< > { boundary_constraints: BoundaryConstraints, logup_table_offset: FieldElement, + /// Captured once per proof (behind the `constraint-ir` feature): the flat + /// IR program for every transition constraint, interpreted in place of + /// the boxed dispatch loop. See `crate::constraint_ir`. + #[cfg(feature = "constraint-ir")] + constraint_program: crate::constraint_ir::ConstraintProgram, phantom: PhantomData<(Field, PI)>, } impl ConstraintEvaluator where - Field: IsSubFieldOf + IsFFTField + Send + Sync, - FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, { /// Evaluate transition + boundary constraints across the entire LDE domain. /// @@ -45,6 +50,8 @@ where num_periodic: usize, offsets: &[usize], logup_table_offset: &FieldElement, + #[cfg(feature = "constraint-ir")] + constraint_program: &crate::constraint_ir::ConstraintProgram, ) -> Vec> { let is_uniform = zerofier_data.is_uniform(); let num_base = air.num_base_transition_constraints(); @@ -97,6 +104,19 @@ where logup_table_offset, &packing_shifts, ); + #[cfg(feature = "constraint-ir")] + { + let ran = crate::constraint_ir::bridge::try_eval_program_prover( + constraint_program, + &ctx, + base_buf, + transition_buf, + ); + if !ran { + air.compute_transition_prover(&ctx, base_buf, transition_buf); + } + } + #[cfg(not(feature = "constraint-ir"))] air.compute_transition_prover(&ctx, base_buf, transition_buf); let acc_transition = if is_uniform { @@ -209,6 +229,8 @@ where Self { boundary_constraints, logup_table_offset, + #[cfg(feature = "constraint-ir")] + constraint_program: air.constraint_program(), phantom: PhantomData::<(Field, PI)> {}, } } @@ -313,6 +335,8 @@ where num_periodic, offsets, &self.logup_table_offset, + #[cfg(feature = "constraint-ir")] + &self.constraint_program, ) } } diff --git a/crypto/stark/src/verifier.rs b/crypto/stark/src/verifier.rs index 68819c76b..705eb4811 100644 --- a/crypto/stark/src/verifier.rs +++ b/crypto/stark/src/verifier.rs @@ -40,8 +40,8 @@ pub struct Verifier< } impl< - Field: IsSubFieldOf + IsFFTField + Send + Sync, - FieldExtension: IsField + Send + Sync, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: IsField + Send + Sync + 'static, PI, > IsStarkVerifier for Verifier { @@ -78,8 +78,8 @@ pub type DeepPolynomialEvaluations = (Vec>, Vec + IsFFTField + Send + Sync, - FieldExtension: Send + Sync + IsField, + Field: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + FieldExtension: Send + Sync + IsField + 'static, PI, > { @@ -206,7 +206,7 @@ pub trait IsStarkVerifier< &packing_shifts, ); let transition_ood_frame_evaluations = - air.compute_transition(&transition_evaluation_context); + Self::step_2_compute_transitions(air, &transition_evaluation_context); let mut denominators = vec![FieldElement::::zero(); air.num_transition_constraints()]; @@ -238,6 +238,35 @@ pub trait IsStarkVerifier< composition_poly_claimed_ood_evaluation == composition_poly_ood_evaluation } + /// Computes the transition-constraint evaluations at the OOD point: the + /// boxed `air.compute_transition(...)` dispatch by default, or (behind + /// the `constraint-ir` feature) the captured IR program interpreted via + /// `crate::constraint_ir::bridge`, falling back to the boxed path if the + /// type tower isn't the lambda_vm Goldilocks one. + #[cfg(not(feature = "constraint-ir"))] + fn step_2_compute_transitions( + air: &dyn AIR, + ctx: &TransitionEvaluationContext, + ) -> Vec> { + air.compute_transition(ctx) + } + + #[cfg(feature = "constraint-ir")] + fn step_2_compute_transitions( + air: &dyn AIR, + ctx: &TransitionEvaluationContext, + ) -> Vec> { + let prog = air.constraint_program(); + let mut evals = + vec![FieldElement::::zero(); air.num_transition_constraints()]; + let ran = crate::constraint_ir::bridge::try_eval_program_verifier(&prog, ctx, &mut evals); + if ran { + evals + } else { + air.compute_transition(ctx) + } + } + /// Reconstructs the Deep composition polynomial evaluations at the challenge indices values using the provided /// openings of the trace polynomials and the composition polynomial parts. It then uses these to verify that the /// FRI decommitments are valid and correspond to the Deep composition polynomial. From 7821c612c93e8a5f3edd487570f93e52ae8a182a Mon Sep 17 00:00:00 2001 From: MauroFab Date: Tue, 30 Jun 2026 18:55:28 -0300 Subject: [PATCH 6/6] =?UTF-8?q?spike(stark):=20close=20constraint-ir=20rev?= =?UTF-8?q?iew=20gaps=20=E2=80=94=202-absorbed=20LogUp=20coverage,=20all?= =?UTF-8?q?=20Packing=20variants,=20no-panic=20fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review of Phases 1-2 found no correctness defects in the existing captures, but four coverage/robustness gaps: 1. LogUp accumulated-constraint branch coverage. Verified via real bus_interactions().len() (not call-site grep counts) that CPU (20) and EQ (6) are BOTH even, so split_interactions gives both absorbed_count==2 — neither exercises the 1-absorbed branch, and full_table_gate's comment claiming otherwise was wrong. No in-repo production table has an odd interaction count > 1. Added self-certifying targeted unit tests in crypto/stark/src/lookup.rs (logup_capture_tests module) that directly construct LookupAccumulatedConstraint with 1- and 2-element absorbed vecs, asserting absorbed.len()/degree() up front so the test can't silently degrade to the wrong branch if edited later. Also added a targeted LookupBatchedTermConstraint test. 2. Packing variant coverage. Only 5/10 Packing variants are ever instantiated by production tables, and DWordHHW/DWordBL had no differential gate at all (e2e-only). Added a test driving capture_fingerprint vs compute_fingerprint_from_step for all 10 variants over 1000 random rows each. 3. Fixed the inaccurate comment in full_table_gate (prover crate) claiming CPU+EQ cover both 1-/2-absorbed branches; corrected to state both are 2-absorbed and point to the new lookup.rs tests for the real coverage. 4. Fixed `cargo test -p stark --features constraint-ir` panicking: the examples/* and test-only AIRs (not part of the IR migration) rely on the default TransitionConstraintEvaluator::capture, which used to call unimplemented!(). Changed the default to call a new IrBuilder::mark_unsupported() instead (no panic), which marks the resulting ConstraintProgram::complete = false. ConstraintEvaluator::new and the verifier's step_2_compute_transitions now check `complete` and cache/use `None` (always fall back to the boxed path) for any AIR whose constraints aren't fully Capture-capable — production lambda_vm AIRs are unaffected (added an assertion to full_table_gate confirming prog.complete == true for CPU/EQ, plus a regression test in transition_tests.rs pinning the default-capture-doesn't-panic behavior). All new tests pass: cargo test -p stark --features constraint-ir (135/135, no panics), cargo test -p stark logup_capture_tests (4/4), cargo test -p lambda-vm-prover --release constraint_ir_tests (6/6) and the full --release suite with the feature on (430/430, unchanged from baseline). fmt/clippy clean on both feature states (only pre-existing, untouched prover_tests.rs warnings remain). Existing committed capture logic is untouched except the one comment fix. --- crypto/stark/src/constraint_ir/builder.rs | 17 + crypto/stark/src/constraint_ir/ir.rs | 9 + crypto/stark/src/constraints/evaluator.rs | 40 ++- crypto/stark/src/constraints/transition.rs | 20 +- crypto/stark/src/lookup.rs | 390 +++++++++++++++++++++ crypto/stark/src/tests/transition_tests.rs | 29 ++ crypto/stark/src/verifier.rs | 6 + prover/src/tests/constraint_ir_tests.rs | 22 +- 8 files changed, 508 insertions(+), 25 deletions(-) diff --git a/crypto/stark/src/constraint_ir/builder.rs b/crypto/stark/src/constraint_ir/builder.rs index 9ed7da4e6..5318615fd 100644 --- a/crypto/stark/src/constraint_ir/builder.rs +++ b/crypto/stark/src/constraint_ir/builder.rs @@ -46,6 +46,13 @@ pub struct IrBuilder { cse: HashMap<(Op, Dim), u32>, const_cache: HashMap, roots: Vec, + /// Set by [`Self::mark_unsupported`] when a constraint couldn't be + /// captured (the default `TransitionConstraintEvaluator::capture` body + /// calls this instead of panicking). Propagated to + /// [`ConstraintProgram::complete`] so callers can fall back to the boxed + /// evaluator for AIRs that aren't fully capture-capable (e.g. the + /// `examples/` and test-only AIRs, not part of the IR migration). + complete: bool, } impl Default for IrBuilder { @@ -63,6 +70,7 @@ impl IrBuilder { cse: HashMap::new(), const_cache: HashMap::new(), roots: Vec::new(), + complete: true, }; // Reserve id 0 = Const1(0). `const_base(0)` will hash-cons to this. let zero = b.push(Op::Const1(0), Dim::D1); @@ -71,6 +79,14 @@ impl IrBuilder { b } + /// Record that the constraint currently being captured has no `Capture` + /// implementation. Does not panic and does not emit a root for it — the + /// resulting program is marked incomplete (see [`ConstraintProgram::complete`]) + /// so callers know not to interpret it. + pub fn mark_unsupported(&mut self) { + self.complete = false; + } + /// Append (or reuse) a node with the given op and result dimension. fn push(&mut self, op: Op, dim: Dim) -> Expr { if let Some(&id) = self.cse.get(&(op, dim)) { @@ -238,6 +254,7 @@ impl IrBuilder { dims: self.dims, roots: self.roots, num_base, + complete: self.complete, } } } diff --git a/crypto/stark/src/constraint_ir/ir.rs b/crypto/stark/src/constraint_ir/ir.rs index 5b6603fba..ea7418c80 100644 --- a/crypto/stark/src/constraint_ir/ir.rs +++ b/crypto/stark/src/constraint_ir/ir.rs @@ -79,6 +79,15 @@ pub struct ConstraintProgram { /// interpreter writes these into `base_evals`; the rest (LogUp, always /// `D3`) go into `ext_evals[num_base..]`. pub num_base: usize, + /// `false` if any constraint in this program was captured via the + /// default `TransitionConstraintEvaluator::capture` (i.e. it has no real + /// `Capture` impl — see [`crate::constraint_ir::builder::IrBuilder::mark_unsupported`]). + /// Callers (the prover/verifier bridge) must not interpret an incomplete + /// program — fall back to the boxed `TransitionConstraintEvaluator` path + /// instead. Every production lambda_vm AIR captures cleanly + /// (`complete: true`); this only trips for the `examples/`/test-only + /// AIRs that predate the IR migration. + pub complete: bool, } impl ConstraintProgram { diff --git a/crypto/stark/src/constraints/evaluator.rs b/crypto/stark/src/constraints/evaluator.rs index 264c8abd8..a41b90c1e 100644 --- a/crypto/stark/src/constraints/evaluator.rs +++ b/crypto/stark/src/constraints/evaluator.rs @@ -23,9 +23,11 @@ pub struct ConstraintEvaluator< logup_table_offset: FieldElement, /// Captured once per proof (behind the `constraint-ir` feature): the flat /// IR program for every transition constraint, interpreted in place of - /// the boxed dispatch loop. See `crate::constraint_ir`. + /// the boxed dispatch loop. `None` if the AIR's constraints aren't all + /// `Capture`-capable (see `ConstraintProgram::complete`) — falls back to + /// the boxed path unconditionally in that case. See `crate::constraint_ir`. #[cfg(feature = "constraint-ir")] - constraint_program: crate::constraint_ir::ConstraintProgram, + constraint_program: Option, phantom: PhantomData<(Field, PI)>, } impl ConstraintEvaluator @@ -50,8 +52,9 @@ where num_periodic: usize, offsets: &[usize], logup_table_offset: &FieldElement, - #[cfg(feature = "constraint-ir")] - constraint_program: &crate::constraint_ir::ConstraintProgram, + #[cfg(feature = "constraint-ir")] constraint_program: Option< + &crate::constraint_ir::ConstraintProgram, + >, ) -> Vec> { let is_uniform = zerofier_data.is_uniform(); let num_base = air.num_base_transition_constraints(); @@ -106,12 +109,14 @@ where ); #[cfg(feature = "constraint-ir")] { - let ran = crate::constraint_ir::bridge::try_eval_program_prover( - constraint_program, - &ctx, - base_buf, - transition_buf, - ); + let ran = constraint_program.is_some_and(|prog| { + crate::constraint_ir::bridge::try_eval_program_prover( + prog, + &ctx, + base_buf, + transition_buf, + ) + }); if !ran { air.compute_transition_prover(&ctx, base_buf, transition_buf); } @@ -226,11 +231,22 @@ where None => FieldElement::zero(), }; + // `complete: false` means some constraint had no real `Capture` impl + // (e.g. an `examples/`/test-only AIR predating the IR migration) — + // don't cache an unusable program; `evaluate_transitions` then always + // takes the boxed path for this AIR. Every production lambda_vm AIR + // captures cleanly. + #[cfg(feature = "constraint-ir")] + let constraint_program = { + let prog = air.constraint_program(); + prog.complete.then_some(prog) + }; + Self { boundary_constraints, logup_table_offset, #[cfg(feature = "constraint-ir")] - constraint_program: air.constraint_program(), + constraint_program, phantom: PhantomData::<(Field, PI)> {}, } } @@ -336,7 +352,7 @@ where offsets, &self.logup_table_offset, #[cfg(feature = "constraint-ir")] - &self.constraint_program, + self.constraint_program.as_ref(), ) } } diff --git a/crypto/stark/src/constraints/transition.rs b/crypto/stark/src/constraints/transition.rs index 4710d6d08..3526791cc 100644 --- a/crypto/stark/src/constraints/transition.rs +++ b/crypto/stark/src/constraints/transition.rs @@ -28,17 +28,17 @@ where /// (non-boxed) counterpart that [`super::transition::TransitionConstraintAdapter`] /// forwards to. /// - /// Default panics: every production constraint must override this (via + /// Default marks the program incomplete (via + /// [`IrBuilder::mark_unsupported`]) rather than panicking: every + /// production lambda_vm constraint overrides this (via /// `TransitionConstraintAdapter` + `Capture`, or directly for the LogUp - /// framework constraints). The default exists only so the many - /// `examples/` and test-only `TransitionConstraintEvaluator` impls (not - /// part of the IR migration) don't need a body. - fn capture(&self, _builder: &mut IrBuilder) { - unimplemented!( - "TransitionConstraintEvaluator::capture not implemented for this constraint; \ - it is not part of the constraint-ir migration (see crypto/stark/src/examples/ \ - or implement Capture for production constraints)" - ); + /// framework constraints), but the many `examples/` and test-only + /// `TransitionConstraintEvaluator` impls (not part of the IR migration) + /// don't need a body — `AIR::constraint_program()` callers must check + /// `ConstraintProgram::complete` and fall back to the boxed evaluator + /// when it's `false`, rather than interpreting a partial program. + fn capture(&self, builder: &mut IrBuilder) { + builder.mark_unsupported(); } /// The function representing the evaluation of the constraint over elements diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index dbcceefef..5b03bf1e6 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -2398,3 +2398,393 @@ where b.emit(self.constraint_idx, root); } } + +#[cfg(test)] +mod logup_capture_tests { + //! Differential tests for `LookupAccumulatedConstraint::capture`, targeting + //! the **2-absorbed** branch specifically. + //! + //! `full_table_gate` (in the `lambda-vm-prover` crate) captures the real + //! CPU and EQ tables, but both happen to have an even `bus_interactions()` + //! count > 2 that nonetheless lands... no — both CPU (20 interactions) and + //! EQ (6 interactions) are even, so `split_interactions` gives them + //! `absorbed_count == 2` as well (verified directly via + //! `bus_interactions().len()`, not the call-site count). There is in fact + //! no in-repo production table with an *odd* interaction count > 1 that + //! would exercise `absorbed_count == 1` through `AirWithBuses` — every + //! real table happens to be even. This module exists to (a) pin down the + //! 2-absorbed branch with a self-certifying targeted test, independent of + //! which production tables happen to hit it, and (b) provide the same + //! coverage for 1-absorbed, since none of the current tables exercise it + //! through the full pipeline. + use super::*; + use crate::constraint_ir::{eval_program, eval_program_verifier}; + use crate::frame::Frame; + use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField as Ext3; + use math::field::goldilocks::GoldilocksField as Gl; + + type Fp = FieldElement; + type Fp3 = FieldElement; + + /// A tiny deterministic SplitMix64 PRNG (same generator as + /// `prover/src/tests/constraint_ir_tests.rs`) so this test needs no `rand` + /// dependency. + struct SplitMix64 { + state: u64, + } + impl SplitMix64 { + fn new(seed: u64) -> Self { + Self { state: seed } + } + fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) + } + } + + const NUM_MAIN_COLS: usize = 4; + const NUM_TERM_COLUMNS: usize = 2; + const NUM_AUX_COLS: usize = NUM_TERM_COLUMNS + 1; + const TRIALS: usize = 1000; + + /// Two absorbed interactions: one sender (`Packing::Direct` on a main + /// column), one receiver (`Multiplicity::Column`, `BusValue::column`), + /// distinct `bus_id`s so the sign/bus-id handling can't accidentally + /// cancel out a translation bug. + fn two_absorbed_interactions() -> Vec { + vec![ + BusInteraction::sender( + 7u64, + Multiplicity::Column(0), + vec![BusValue::Packed { + start_column: 1, + packing: Packing::Direct, + }], + ), + BusInteraction::receiver(11u64, Multiplicity::Column(2), vec![BusValue::column(3)]), + ] + } + + /// One absorbed interaction (the 1-absorbed branch), for contrast. + fn one_absorbed_interaction() -> Vec { + vec![BusInteraction::sender( + 7u64, + Multiplicity::Column(0), + vec![BusValue::Packed { + start_column: 1, + packing: Packing::Direct, + }], + )] + } + + /// Build a random 2-step frame (offsets 0 and 1) with `NUM_MAIN_COLS` main + /// columns and `NUM_AUX_COLS` aux columns, plus matching rap challenges / + /// alpha powers / table offset / packing shifts. + #[allow(clippy::type_complexity)] + fn random_inputs( + seed: u64, + ) -> ( + Frame, + Vec, // rap_challenges (z, alpha) + Vec, // logup_alpha_powers + Fp3, // logup_table_offset + PackingShifts, + ) { + let mut rng = SplitMix64::new(seed); + + fn rand_fp3(rng: &mut SplitMix64) -> Fp3 { + FieldElement::::new([ + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + ]) + } + + fn step(rng: &mut SplitMix64) -> TableView { + let main: Vec = (0..NUM_MAIN_COLS) + .map(|_| Fp::from(rng.next_u64())) + .collect(); + let aux: Vec = (0..NUM_AUX_COLS).map(|_| rand_fp3(rng)).collect(); + TableView::new(vec![main], vec![aux]) + } + + let frame = Frame::new(vec![step(&mut rng), step(&mut rng)]); + let rap_challenges = vec![rand_fp3(&mut rng), rand_fp3(&mut rng)]; // [z, alpha] (alpha unused by fingerprint capture here) + // alpha_powers must cover every alpha index the captured interactions + // read; `Direct`/`column` each consume exactly 1, and `capture_fingerprint` + // starts at alpha_idx 1 (alpha^0 is the bus_id), so index 1 suffices for + // each interaction captured independently — give a generous buffer. + let logup_alpha_powers: Vec = (0..8).map(|_| rand_fp3(&mut rng)).collect(); + let logup_table_offset = rand_fp3(&mut rng); + let packing_shifts = PackingShifts::::new(); + + ( + frame, + rap_challenges, + logup_alpha_powers, + logup_table_offset, + packing_shifts, + ) + } + + /// Run the differential check for a given `absorbed` set: capture `c` via + /// `IrBuilder`, then for `TRIALS` random frames compare the boxed + /// `evaluate_verifier` (used for both the prover's `ext_evals` write and + /// the verifier path) against `eval_program`/`eval_program_verifier`, + /// bit-for-bit. + fn assert_ir_matches_evaluate(absorbed: Vec, label: &str) { + let c = LookupAccumulatedConstraint::new(0, NUM_TERM_COLUMNS, absorbed); + + let mut b = IrBuilder::new(); + TransitionConstraintEvaluator::::capture(&c, &mut b); + let prog = b.finish(0); // num_base = 0: this constraint is always D3-rooted + + assert!( + prog.complete, + "[{label}] capture must not fall back to the default unsupported marker" + ); + + for trial in 0..TRIALS { + let (frame, rap_challenges, logup_alpha_powers, logup_table_offset, packing_shifts) = + random_inputs(0xC0FF_EE00_u64.wrapping_add(trial as u64)); + + // --- Prover-side oracle vs eval_program --- + let prover_ctx = TransitionEvaluationContext::new_prover( + &frame, + &[], // no periodic columns + &rap_challenges, + &logup_alpha_powers, + &logup_table_offset, + &packing_shifts, + ); + let mut oracle_ext = vec![FieldElement::::zero(); 1]; + c.evaluate_verifier(&prover_ctx, &mut oracle_ext); + + let mut ir_base: Vec = vec![]; + let mut ir_ext = vec![FieldElement::::zero(); 1]; + eval_program(&prog, &prover_ctx, &mut ir_base, &mut ir_ext); + + assert_eq!( + oracle_ext, ir_ext, + "[{label}] prover mismatch at trial {trial}" + ); + + // --- Verifier-side oracle vs eval_program_verifier --- + // The verifier frame holds only extension-field elements; embed + // the same step data into Ext3 (mirrors `full_table_gate`). + let embed_step = |step: &TableView| -> TableView { + let main: Vec = (0..NUM_MAIN_COLS) + .map(|c| (*step.get_main_evaluation_element(0, c)).to_extension()) + .collect(); + let aux: Vec = (0..NUM_AUX_COLS) + .map(|c| *step.get_aux_evaluation_element(0, c)) + .collect(); + TableView::new(vec![main], vec![aux]) + }; + let verifier_frame: Frame = Frame::new(vec![ + embed_step(frame.get_evaluation_step(0)), + embed_step(frame.get_evaluation_step(1)), + ]); + let verifier_packing_shifts = PackingShifts::::new(); + let verifier_ctx = TransitionEvaluationContext::new_verifier( + &verifier_frame, + &[], + &rap_challenges, + &logup_alpha_powers, + &logup_table_offset, + &verifier_packing_shifts, + ); + let mut oracle_verifier = vec![FieldElement::::zero(); 1]; + c.evaluate_verifier(&verifier_ctx, &mut oracle_verifier); + + let mut ir_verifier = vec![FieldElement::::zero(); 1]; + eval_program_verifier(&prog, &verifier_ctx, &mut ir_verifier); + + assert_eq!( + oracle_verifier, ir_verifier, + "[{label}] verifier mismatch at trial {trial}" + ); + } + } + + #[test] + fn test_ir_matches_accumulated_constraint_two_absorbed() { + let absorbed = two_absorbed_interactions(); + // Self-certify: this test targets the 2-absorbed branch specifically. + // If a future edit shrinks `two_absorbed_interactions` to 1 element, + // this assertion fails loudly instead of silently degrading to the + // 1-absorbed branch. + assert_eq!( + absorbed.len(), + 2, + "this test must exercise absorbed.len()==2" + ); + let c = LookupAccumulatedConstraint::new(0, NUM_TERM_COLUMNS, absorbed.clone()); + assert_eq!( + TransitionConstraintEvaluator::::degree(&c), + 3, + "absorbed.len()==2 must select the degree-3 (f1*f2) branch" + ); + assert_ir_matches_evaluate(absorbed, "accumulated_two_absorbed"); + } + + #[test] + fn test_ir_matches_accumulated_constraint_one_absorbed() { + let absorbed = one_absorbed_interaction(); + assert_eq!( + absorbed.len(), + 1, + "this test must exercise absorbed.len()==1" + ); + let c = LookupAccumulatedConstraint::new(0, NUM_TERM_COLUMNS, absorbed.clone()); + assert_eq!( + TransitionConstraintEvaluator::::degree(&c), + 2, + "absorbed.len()==1 must select the degree-2 (f only) branch" + ); + assert_ir_matches_evaluate(absorbed, "accumulated_one_absorbed"); + } + + #[test] + fn test_ir_matches_batched_term_constraint() { + // LookupBatchedTermConstraint: always exactly 2 interactions (a, b), + // committed to one aux "term" column (column 0 here). Covered + // end-to-end by `full_table_gate`'s CPU/EQ programs already, but a + // targeted test pins the formula independent of table layout. + let interaction_a = BusInteraction::sender( + 7u64, + Multiplicity::Column(0), + vec![BusValue::Packed { + start_column: 1, + packing: Packing::Direct, + }], + ); + let interaction_b = + BusInteraction::receiver(11u64, Multiplicity::Column(2), vec![BusValue::column(3)]); + let c = LookupBatchedTermConstraint::new(interaction_a, interaction_b, 0, 0); + + let mut b = IrBuilder::new(); + TransitionConstraintEvaluator::::capture(&c, &mut b); + let prog = b.finish(0); + assert!(prog.complete); + + for trial in 0..TRIALS { + let (frame, rap_challenges, logup_alpha_powers, logup_table_offset, packing_shifts) = + random_inputs(0xFACE_F00D_u64.wrapping_add(trial as u64)); + + let prover_ctx = TransitionEvaluationContext::new_prover( + &frame, + &[], + &rap_challenges, + &logup_alpha_powers, + &logup_table_offset, + &packing_shifts, + ); + let mut oracle_ext = vec![FieldElement::::zero(); 1]; + c.evaluate_verifier(&prover_ctx, &mut oracle_ext); + + let mut ir_base: Vec = vec![]; + let mut ir_ext = vec![FieldElement::::zero(); 1]; + eval_program(&prog, &prover_ctx, &mut ir_base, &mut ir_ext); + + assert_eq!( + oracle_ext, ir_ext, + "batched_term prover mismatch at trial {trial}" + ); + } + } + + /// Differential coverage for **every** `Packing` variant's fingerprint + /// contribution: `capture_fingerprint` (IR) vs `compute_fingerprint_from_step` + /// (boxed), bit-for-bit on random rows. Only 5 of 10 variants are ever + /// instantiated by production tables (Direct, DWordWL, DWordHL, DWordBL, + /// DWordHHW); the live ones besides DWordHHW/DWordBL have no other + /// differential gate (only e2e coverage via `full_table_gate`), and the + /// other 5 (Word2L, Word4L, DWordWHH, QuadHL, QuadWL) have none at all. + /// This test drives `capture_packing_fingerprint`'s full `match` arm-by-arm. + #[test] + fn test_capture_fingerprint_matches_for_all_packing_variants() { + const ALL_PACKINGS: [Packing; 10] = [ + Packing::Direct, + Packing::Word2L, + Packing::Word4L, + Packing::DWordWL, + Packing::DWordHHW, + Packing::DWordWHH, + Packing::DWordHL, + Packing::DWordBL, + Packing::QuadHL, + Packing::QuadWL, + ]; + + // Enough main columns for the widest packing (DWordBL/QuadHL: 8). + const NUM_COLS: usize = 8; + + for packing in ALL_PACKINGS { + let interaction = BusInteraction::sender( + 13u64, + Multiplicity::One, + vec![BusValue::Packed { + start_column: 0, + packing, + }], + ); + + // Capture once per packing. + let mut b = IrBuilder::new(); + let root = capture_fingerprint(&mut b, &interaction, 0); + b.emit(0, root); + let prog = b.finish(0); + assert!(prog.complete); + + let mut rng = SplitMix64::new(0x9E11_u64.wrapping_add(packing.num_columns() as u64)); + for trial in 0..TRIALS { + let main: Vec = (0..NUM_COLS).map(|_| Fp::from(rng.next_u64())).collect(); + let step: TableView = + TableView::new(vec![main.clone()], vec![Vec::new()]); + + let z = FieldElement::::new([ + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + ]); + let alpha_powers: Vec = (0..6) + .map(|_| { + FieldElement::::new([ + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + Fp::from(rng.next_u64()), + ]) + }) + .collect(); + let shifts = PackingShifts::::new(); + + let oracle = + compute_fingerprint_from_step(&step, &interaction, &z, &alpha_powers, &shifts); + + // Drive `eval_program`'s shared `run()` walk directly (no full + // TransitionEvaluationContext needed: the captured program's + // only leaves are `Var{main}`, `RapChallenge{0}`, `AlphaPow{idx}`). + let logup_table_offset = FieldElement::::zero(); + let rap_challenges = vec![z]; + let frame = Frame::new(vec![step]); + let ctx = TransitionEvaluationContext::new_prover( + &frame, + &[], + &rap_challenges, + &alpha_powers, + &logup_table_offset, + &shifts, + ); + let mut ir_base: Vec = vec![]; + let mut ir_ext = vec![FieldElement::::zero(); 1]; + eval_program(&prog, &ctx, &mut ir_base, &mut ir_ext); + + assert_eq!(oracle, ir_ext[0], "{packing:?} mismatch at trial {trial}"); + } + } + } +} diff --git a/crypto/stark/src/tests/transition_tests.rs b/crypto/stark/src/tests/transition_tests.rs index 17bfaa6cc..5b78c9adb 100644 --- a/crypto/stark/src/tests/transition_tests.rs +++ b/crypto/stark/src/tests/transition_tests.rs @@ -83,3 +83,32 @@ fn end_exemptions_roots_zero_exemptions_is_empty() { assert!(c.end_exemptions_roots(&g, trace_length).is_empty()); } + +/// `DummyConstraint` doesn't override `capture`, so it exercises the default +/// `TransitionConstraintEvaluator::capture` body — which must not panic (see +/// `crypto/stark/src/constraints/transition.rs`) and must mark the resulting +/// `ConstraintProgram` incomplete via `IrBuilder::mark_unsupported`, so +/// `ConstraintEvaluator`/the verifier fall back to the boxed path instead of +/// interpreting a partial program. This is the regression test for the +/// `cargo test -p stark --features constraint-ir` panic fixed alongside this +/// test (every `examples/`/test-only AIR relies on this default). +#[test] +fn default_capture_marks_program_incomplete_without_panicking() { + use crate::constraint_ir::IrBuilder; + + let c = DummyConstraint:: { + period: 1, + offset: 0, + end_exemptions: 0, + phantom: PhantomData, + }; + + let mut b = IrBuilder::new(); + c.capture(&mut b); // must not panic + let prog = b.finish(0); + + assert!( + !prog.complete, + "a constraint with no Capture impl must mark the program incomplete" + ); +} diff --git a/crypto/stark/src/verifier.rs b/crypto/stark/src/verifier.rs index 705eb4811..4c05556d6 100644 --- a/crypto/stark/src/verifier.rs +++ b/crypto/stark/src/verifier.rs @@ -257,6 +257,12 @@ pub trait IsStarkVerifier< ctx: &TransitionEvaluationContext, ) -> Vec> { let prog = air.constraint_program(); + // `complete: false` means some constraint had no real `Capture` impl + // (e.g. an `examples/`/test-only AIR) — don't interpret a partial + // program, fall back to the boxed path unconditionally. + if !prog.complete { + return air.compute_transition(ctx); + } let mut evals = vec![FieldElement::::zero(); air.num_transition_constraints()]; let ran = crate::constraint_ir::bridge::try_eval_program_verifier(&prog, ctx, &mut evals); diff --git a/prover/src/tests/constraint_ir_tests.rs b/prover/src/tests/constraint_ir_tests.rs index 226f67412..d72f71ffd 100644 --- a/prover/src/tests/constraint_ir_tests.rs +++ b/prover/src/tests/constraint_ir_tests.rs @@ -190,9 +190,19 @@ mod full_table_gate { } /// A handful of real EQ operations (BEQ-style and BNE-style) so the - /// captured LogUp program (1 batched pair + 1 absorbed interaction, the - /// branch not exercised by the CPU table below) sees non-trivial - /// fingerprints on every row. + /// captured LogUp program sees non-trivial fingerprints on every row. + /// + /// Both this table and CPU (below) have an even `bus_interactions().len()` + /// (EQ: 6, CPU: 20), so `split_interactions` gives both an `absorbed_count` + /// of **2**, not 1 — there is in fact no in-repo production table whose + /// real interaction count is odd-and->1, so nothing here exercises + /// `LookupAccumulatedConstraint`'s 1-absorbed branch end-to-end. Both the + /// 1- and 2-absorbed branches (plus `LookupBatchedTermConstraint`, and all + /// 10 `Packing` variants) are covered by targeted, self-certifying + /// differential tests in `crypto/stark/src/lookup.rs`'s + /// `logup_capture_tests` module instead (each asserts `absorbed.len()`/ + /// `degree()` up front so the test can't silently degrade to the wrong + /// branch). fn sample_eq_operations() -> Vec { vec![ EqOperation::new(42, 42, false), @@ -262,6 +272,12 @@ mod full_table_gate { "every constraint_idx must have been emitted" ); assert_eq!(prog.num_base, air.num_base_transition_constraints()); + assert!( + prog.complete, + "[{label}] every production constraint must have a real Capture impl \ + (a constraint fell back to the default IrBuilder::mark_unsupported, \ + which would make ConstraintEvaluator skip the IR path entirely for this AIR)" + ); let num_base = air.num_base_transition_constraints(); let num_transition = air.num_transition_constraints();