diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index f00985d39..d53f06f10 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -168,6 +168,19 @@ where }) } + /// Create a root only Merkle tree placeholder: stores the commitment root + /// but no nodes. Used when paths are gathered from a device resident copy + /// (GPU) instead of this host tree, so the host nodes are never built. + /// [`get_proof_by_pos`](Self::get_proof_by_pos) must NOT be called on it. + pub fn from_root(root: B::Node) -> Self { + MerkleTree { + root, + nodes: Vec::new(), + #[cfg(feature = "disk-spill")] + mmap_backing: None, + } + } + /// Create a Merkle tree from pre-hashed leaf nodes. /// /// This skips the `hash_leaves` step, useful when leaves have already been diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index 557b8dd43..e7bb8a618 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -393,6 +393,40 @@ extern "C" __global__ void keccak_merkle_level( finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32); } +// Gather Merkle authentication paths for a batch of leaf positions, reading the +// resident tree `nodes` (32-byte nodes; layout: inner nodes [0..leaves_len-1], +// root at 0, leaves at [leaves_len-1..]). One thread per query walks leaf->root, +// writing each sibling node into the output. This mirrors the CPU +// `build_merkle_path` exactly (sibling_index / parent_index in +// crypto/crypto/src/merkle_tree/utils.rs): +// leaf node = pos + leaves_len - 1 +// sibling = node even ? node-1 : node+1 +// parent = node even ? (node-1)/2 : node/2 +// so `out[(q*depth + level)*32 .. +32]` is the level-th sibling for query q. 
+extern "C" __global__ void merkle_gather_paths( + const uint8_t *nodes, + const uint32_t *positions, // leaf positions, length num_queries + uint32_t num_queries, + uint64_t leaves_len, + uint32_t depth, // = log2(leaves_len) + uint8_t *out) { // num_queries * depth * 32 bytes + uint32_t q = blockIdx.x * blockDim.x + threadIdx.x; + if (q >= num_queries) return; + + uint64_t node = (uint64_t)positions[q] + leaves_len - 1; + for (uint32_t level = 0; level < depth; ++level) { + uint64_t sib = (node & 1ull) ? (node + 1ull) : (node - 1ull); + // 32-byte nodes at 32-byte-aligned offsets (cuMemAlloc 256-aligned), + // so the u64 copy is safe. + const uint64_t *src = reinterpret_cast(nodes + sib * 32); + uint64_t *dst = reinterpret_cast( + out + ((uint64_t)q * depth + level) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) dst[i] = src[i]; + node = (node & 1ull) ? (node >> 1) : ((node - 1ull) >> 1); + } +} + // --------------------------------------------------------------------------- // Row-major ROW-PAIR leaf hashing. // diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 4270e5da8..3a149a83f 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -118,6 +118,9 @@ pub struct Backend { pinned_hashes: Vec>, util_stream: Arc, next: AtomicUsize, + /// VRAM budget (bytes) for table-session admission control. See + /// [`detect_vram_budget_bytes`]. 
+ vram_budget_bytes: u64, // arith.ptx pub vector_add_u64: CudaFunction, @@ -154,6 +157,7 @@ pub struct Backend { pub keccak_comp_poly_leaves_ext3: CudaFunction, pub keccak_fri_leaves_ext3: CudaFunction, pub keccak_merkle_level: CudaFunction, + pub merkle_gather_paths: CudaFunction, // barycentric.ptx pub barycentric_base_batched: CudaFunction, @@ -181,6 +185,74 @@ pub struct Backend { inv_twiddles: Mutex>>>>, } +/// Raise the device default memory pool's release threshold so freed +/// stream-ordered allocations are kept for reuse instead of returned to the OS +/// at each sync. Best-effort: any failure (e.g. a device/driver without +/// stream-ordered allocator support) leaves the default behaviour untouched. +fn retain_default_mempool(ctx: &CudaContext) { + use cudarc::driver::sys; + // SAFETY: raw CUDA driver calls. `ctx.cu_device()` is a valid device for + // the just-created context; the out-pointers are valid stack slots; the + // threshold is read as a u64 by the driver. Errors are swallowed. + unsafe { + let dev = ctx.cu_device(); + let mut pool: sys::CUmemoryPool = std::ptr::null_mut(); + if sys::cuDeviceGetDefaultMemPool(&mut pool as *mut _, dev) + .result() + .is_err() + { + return; + } + // Default: retain freed stream-ordered blocks indefinitely (u64::MAX) + // for reuse. `LAMBDA_VM_MEMPOOL_RELEASE_MB` overrides the cap (bytes the + // pool keeps before returning memory to the OS) when retained-pool + // growth needs bounding. + let threshold: u64 = std::env::var("LAMBDA_VM_MEMPOOL_RELEASE_MB") + .ok() + .and_then(|s| s.parse::().ok()) + .map(|mb| mb.saturating_mul(1024 * 1024)) + .unwrap_or(u64::MAX); + let _ = sys::cuMemPoolSetAttribute( + pool, + sys::CUmemPool_attribute_enum::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, + &threshold as *const u64 as *mut core::ffi::c_void, + ) + .result(); + } +} + +/// Device VRAM budget in bytes for table session admission control. +/// +/// LAMBDA_VM_VRAM_BUDGET_MB overrides it (used to force the throttle in tests). 
+/// Otherwise it is 80% of total device memory, leaving headroom for the +/// context, module code, and retained pool blocks. Returns u64::MAX on any +/// query failure, which disables budgeting (chunks fall back to the core bound +/// size alone). +fn detect_vram_budget_bytes(ctx: &CudaContext) -> u64 { + if let Ok(mb) = std::env::var("LAMBDA_VM_VRAM_BUDGET_MB") + && let Ok(mb) = mb.parse::() + { + return mb.saturating_mul(1024 * 1024); + } + use cudarc::driver::sys; + // SAFETY: raw driver query writing into two stack slots. The caller's + // context is already current (it was just created in `init`). Any error + // falls through to the budgeting-disabled sentinel. + unsafe { + let _ = ctx; + let mut free: usize = 0; + let mut total: usize = 0; + if sys::cuMemGetInfo_v2(&mut free as *mut usize, &mut total as *mut usize) + .result() + .is_err() + { + return u64::MAX; + } + // 80% of total, computed to avoid intermediate overflow. + (total as u64) / 5 * 4 + } +} + impl Backend { fn init() -> Result { let ctx = CudaContext::new(0)?; @@ -190,6 +262,17 @@ impl Backend { // before returning), so the tracking is pure overhead. Disable it. unsafe { ctx.disable_event_tracking() }; + // Retain freed device memory in the stream ordered pool for reuse. + // + // cudarc routes CudaStream::alloc* through cuMemAllocAsync, drawing from + // the device default memory pool. Its release threshold defaults to 0, + // so every freed buffer goes back to the OS at the next sync and the + // prover's large LDE/FRI buffers are rebuilt from scratch each op. + // Raising the threshold keeps freed blocks in the pool so a same size + // allocation skips a real driver allocation. Best effort: on any error + // we keep the current behaviour. 
+ retain_default_mempool(&ctx); + let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; @@ -225,6 +308,8 @@ impl Backend { // Length = TWO_ADICITY + 1 to allow indexing at log_n = TWO_ADICITY. let max_log = GoldilocksField::TWO_ADICITY as usize + 1; + let vram_budget_bytes = detect_vram_budget_bytes(&ctx); + Ok(Self { vector_add_u64: arith.load_function("vector_add_u64")?, gl_add: arith.load_function("gl_add_kernel")?, @@ -257,6 +342,7 @@ impl Backend { keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?, keccak_merkle_level: keccak.load_function("keccak_merkle_level")?, + merkle_gather_paths: keccak.load_function("merkle_gather_paths")?, barycentric_base_batched: bary.load_function("barycentric_base_batched")?, barycentric_ext3_batched: bary.load_function("barycentric_ext3_batched")?, barycentric_base_batched_strided: bary @@ -282,9 +368,16 @@ impl Backend { pinned_hashes, util_stream, next: AtomicUsize::new(0), + vram_budget_bytes, }) } + /// VRAM budget in bytes for table-session admission control. `u64::MAX` + /// when budgeting is disabled (query failed). See the field docs. + pub fn vram_budget_bytes(&self) -> u64 { + self.vram_budget_bytes + } + /// Round-robin over the stream pool. Concurrent callers get different /// streams so their kernel launches overlap on the GPU. 
pub fn next_stream(&self) -> Arc { diff --git a/crypto/math-cuda/src/fri.rs b/crypto/math-cuda/src/fri.rs index edd359b1b..a2f96c07a 100644 --- a/crypto/math-cuda/src/fri.rs +++ b/crypto/math-cuda/src/fri.rs @@ -98,7 +98,7 @@ impl FriCommitState { pub fn fold_and_commit_layer( &mut self, zeta_raw: [u64; 3], - ) -> Result<(Vec, Vec, Vec)> { + ) -> Result<(Vec, crate::lde::GpuMerkleTree)> { #[cfg(feature = "test-faults")] check_fault_injection()?; let be = backend()?; @@ -214,17 +214,22 @@ impl FriCommitState { self.stream.clone_dtoh(&view)? }; - // Tree nodes. - let nodes_bytes: Vec = self.stream.clone_dtoh(&nodes_dev)?; - debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); - - let mut root = vec![0u8; 32]; - root.copy_from_slice(&nodes_bytes[0..32]); + // Keep the layer tree resident on device; copy only the 32-byte root so + // R4 query openings gather paths on device instead of copying the tree. + let mut root = [0u8; 32]; + self.stream + .memcpy_dtoh(&nodes_dev.slice(0..32), &mut root)?; + self.stream.synchronize()?; self.a_is_input = !self.a_is_input; self.current_n = n_out; - Ok((root, layer_evals, nodes_bytes)) + let tree = crate::lde::GpuMerkleTree { + nodes: std::sync::Arc::new(nodes_dev), + leaves_len: num_leaves, + root, + }; + Ok((layer_evals, tree)) } /// Final fold, no Merkle commit. 
Returns the single ext3 output diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index 5ed58fe87..427d84351 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -397,7 +397,7 @@ fn coset_lde_row_major_inner( blowup_factor: usize, weights: &[u64], what: &str, -) -> Result<(Vec, CudaSlice, Vec)> { +) -> Result<(GpuMerkleTree, CudaSlice, Vec)> { assert_eq!(row_major.len(), n * total_cols); assert!(n.is_power_of_two()); assert_eq!(weights.len(), n); @@ -489,33 +489,44 @@ fn coset_lde_row_major_inner( out }; - let mut nodes_out = vec![0u8; nodes_bytes]; - d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, &mut nodes_out)?; + // Keep the Merkle tree resident on device; copy only the 32 byte root so the + // commitment is available without copying the whole tree. Query openings + // gather paths from the device tree (see merkle::gather_merkle_paths_dev). + let mut root = [0u8; 32]; + stream.memcpy_dtoh(&nodes_dev.slice(0..32), &mut root)?; - // Transpose row-major buf → column-major for the handle. Downstream kernels - // (DEEP, barycentric) expect buf[c * lde_size + r] (column-major). + // Transpose row-major buf into column-major for the handle. Downstream + // kernels (DEEP, barycentric) expect buf[c * lde_size + r] (column-major). let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, total_cols, lde_u64)?; - // Synchronize before returning: the handle crosses stream boundaries — downstream - // consumers call be.next_stream() and read handle.buf on a different stream. - // Without this, a barycentric or DEEP kernel can start before the transpose finishes. + // Synchronize before returning: the handle crosses stream boundaries. + // Downstream consumers call be.next_stream() and read handle.buf on a + // different stream, and the root copy above must have landed. 
stream.synchronize()?; - Ok((nodes_out, col_major_dev, lde_out)) + let tree = GpuMerkleTree { + nodes: Arc::new(nodes_dev), + leaves_len: num_leaves, + root, + }; + Ok((tree, col_major_dev, lde_out)) } -/// Row-major LDE + Keccak + Merkle, all on-device. +/// Row-major LDE + Keccak + Merkle, all on-device, keeping the Merkle tree +/// resident on device (in the handle's `tree`). The host tree is not built, so +/// the whole tree copy to host is eliminated; query openings gather paths from +/// the device tree. /// -/// Input: `row_major` is a flat `n * m` slice in row-major order. -/// Returns (merkle_nodes, GpuLdeBase handle, row-major LDE Vec). -/// The returned handle is column-major (as required by downstream GPU kernels). +/// Input: `row_major` is a flat `n * m` slice in row-major order. Returns the +/// `GpuLdeBase` handle (column-major buf, plus the device tree) and the +/// row-major LDE Vec. pub fn coset_lde_row_major_with_merkle_tree_keep( row_major: &[u64], n: usize, m: usize, blowup_factor: usize, weights: &[u64], -) -> Result<(Vec, GpuLdeBase, Vec)> { - let (nodes_out, col_major_dev, lde_out) = coset_lde_row_major_inner( +) -> Result<(GpuLdeBase, Vec)> { + let (tree, col_major_dev, lde_out) = coset_lde_row_major_inner( row_major, n, m, @@ -527,8 +538,9 @@ pub fn coset_lde_row_major_with_merkle_tree_keep( buf: Arc::new(col_major_dev), m, lde_size: n * blowup_factor, + tree: Some(tree), }; - Ok((nodes_out, handle, lde_out)) + Ok((handle, lde_out)) } /// Row-major ext3 LDE + Keccak + Merkle, all on-device. 
@@ -547,8 +559,8 @@ pub fn coset_lde_ext3_row_major_with_merkle_tree_keep( m: usize, blowup_factor: usize, weights: &[u64], -) -> Result<(Vec, GpuLdeExt3, Vec)> { - let (nodes_out, col_major_dev, lde_out) = coset_lde_row_major_inner( +) -> Result<(GpuLdeExt3, Vec)> { + let (tree, col_major_dev, lde_out) = coset_lde_row_major_inner( row_major, n, m * 3, @@ -560,18 +572,24 @@ pub fn coset_lde_ext3_row_major_with_merkle_tree_keep( buf: Arc::new(col_major_dev), m, lde_size: n * blowup_factor, + tree: Some(tree), }; - Ok((nodes_out, handle, lde_out)) + Ok((handle, lde_out)) } /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. +/// +/// `tree` optionally carries the main trace Merkle tree kept resident on device +/// (the keep path), so R4 query openings gather paths on device instead of +/// copying the whole tree to host. `None` for handles built without the keep path. #[derive(Clone)] pub struct GpuLdeBase { pub buf: Arc>, pub m: usize, pub lde_size: usize, + pub tree: Option, } /// Handle to an ext3 LDE kept live on device, de-interleaved into 3 base @@ -582,6 +600,23 @@ pub struct GpuLdeExt3 { pub buf: Arc>, pub m: usize, pub lde_size: usize, + /// Optionally the aux or composition Merkle tree kept resident on device + /// (the keep path), so R4 openings gather paths on device. None otherwise. + pub tree: Option, +} + +/// Merkle tree kept resident on device after a commit, so query openings gather +/// paths on device instead of copying the whole tree to host. Node layout +/// matches the CPU tree (`crypto/crypto/src/merkle_tree`): `nodes[0..leaves_len-1]` +/// are inner nodes (root at 0), `nodes[leaves_len-1..]` are the leaves, each 32 +/// bytes. Freed when the `nodes` Arc drops. 
+#[derive(Clone)] +pub struct GpuMerkleTree { + pub nodes: Arc>, + pub leaves_len: usize, + /// The Merkle root (node 0), copied to host at build time so the commitment + /// is available without copying the whole tree. + pub root: [u8; 32], } pub fn coset_lde_base(evals: &[u64], blowup_factor: usize, weights: &[u64]) -> Result> { @@ -1141,6 +1176,7 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner( buf: Arc::new(buf), m, lde_size, + tree: None, })) } else { drop(buf); @@ -1331,6 +1367,7 @@ fn evaluate_poly_coset_batch_ext3_into_inner( buf: std::sync::Arc::new(buf), m, lde_size, + tree: None, })) } else { drop(buf); diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs index 27f38ce0a..fb1125ea4 100644 --- a/crypto/math-cuda/src/merkle.rs +++ b/crypto/math-cuda/src/merkle.rs @@ -17,6 +17,7 @@ //! to match `FieldElement::::write_bytes_be`. use cudarc::driver::{CudaSlice, CudaStream, CudaViewMut, LaunchConfig, PushKernelArg}; +use std::sync::Arc; use crate::Result; use crate::device::{Backend, backend}; @@ -316,13 +317,75 @@ pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { Ok(out) } -/// Row-pair Keccak leaf + Merkle tree build for R2 composition-polynomial -/// commit. `parts_interleaved` is `num_parts` slices, each holding an ext3 -/// LDE column interleaved as `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. -/// -/// Returns `(2*(lde_size/2) - 1) * 32` bytes of tree nodes in the standard -/// layout (root at byte offset 0, leaves in the tail). -pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Result> { +/// Gather Merkle authentication paths on device for `positions` (leaf indices) +/// against the resident tree `nodes_dev` (standard layout, `2*leaves_len-1` +/// nodes of 32 bytes). Returns `positions.len() * depth * 32` bytes, where +/// `depth = log2(leaves_len)`. Query `q`'s path is `[q*depth*32 .. +/// (q+1)*depth*32]`, each 32 byte node a sibling from leaf to root. 
These are +/// the same nodes the CPU `MerkleTree::get_proof_by_pos` collects. Runs on the +/// caller's `stream` (pass the table's session stream). Panics if `leaves_len` is not a power of two >= 2 or any position is >= `leaves_len`; empty `positions` returns an empty Vec without touching the device. +pub fn gather_merkle_paths_dev( + nodes_dev: &CudaSlice, + leaves_len: usize, + positions: &[u32], + stream: &Arc, +) -> Result> { + let num_queries = positions.len(); + if num_queries == 0 { + return Ok(Vec::new()); + } + assert!( + leaves_len.is_power_of_two() && leaves_len >= 2, + "leaves_len must be a power of two >= 2" + ); + let depth = leaves_len.trailing_zeros() as usize; + // Guard the kernel's device reads: a position past leaves_len would walk + // off the node buffer. Positions are valid by construction; this catches a + // caller bug before it becomes an out of bounds device read. NOTE(review): `nodes_dev.len()` is not checked against `(2 * leaves_len - 1) * 32` here. + assert!( + positions.iter().all(|&p| (p as usize) < leaves_len), + "gather_merkle_paths_dev: leaf position >= leaves_len" + ); + let be = backend()?; + + let pos_dev = stream.clone_htod(positions)?; + // SAFETY: every byte of `out` is written by the kernel below (one 32-byte + // node per (query, level)) before the D2H reads it back. + let mut out = unsafe { stream.alloc::(num_queries * depth * 32) }?; + + let grid = (num_queries as u32).div_ceil(KECCAK_BLOCK_DIM); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (KECCAK_BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + let num_queries_u32 = num_queries as u32; + let leaves_len_u64 = leaves_len as u64; + let depth_u32 = depth as u32; + unsafe { // SAFETY: args are pushed in the order the `merkle_gather_paths` kernel declares them; positions were bounds-checked above. + stream + .launch_builder(&be.merkle_gather_paths) + .arg(nodes_dev) + .arg(&pos_dev) + .arg(&num_queries_u32) + .arg(&leaves_len_u64) + .arg(&depth_u32) + .arg(&mut out) + .launch(cfg)?; + } + let host = stream.clone_dtoh(&out)?; + stream.synchronize()?; + Ok(host) +} + +/// Build the composition Merkle tree on device. `parts_interleaved` is +/// `num_parts` slices, each an ext3 LDE column interleaved as +/// `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. Leaves hash row pairs, so +/// `num_leaves = lde_size / 2`. 
Returns the device node buffer, the leaf count, +/// and the stream it was built on. Used by the device keep wrapper below. +fn build_comp_poly_tree_nodes_dev( + parts_interleaved: &[&[u64]], +) -> Result<(CudaSlice, usize, Arc)> { assert!(!parts_interleaved.is_empty()); let m = parts_interleaved.len(); let ext3_elems = parts_interleaved[0].len() / 3; @@ -351,9 +414,13 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res pack_ext3_to_pinned_slabs(parts_interleaved, pinned, lde_size); - // H2D the de-interleaved parts. + // H2D the de-interleaved parts, then release the staging lock (the kernels + // below read the device `buf`, not `pinned`). Synchronize first so the + // async H2D has consumed `pinned` before it is freed/reused. let mut buf = stream.alloc_zeros::(mb * lde_size)?; stream.memcpy_htod(&pinned[..mb * lde_size], &mut buf)?; + stream.synchronize()?; + drop(staging); // Leaves into tail of a tight node buffer. let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; @@ -380,18 +447,33 @@ pub fn build_comp_poly_tree_from_evals_ext3(parts_interleaved: &[&[u64]]) -> Res } build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?; + Ok((nodes_dev, num_leaves, stream)) +} - let out = stream.clone_dtoh(&nodes_dev)?; +/// Build the comp poly Merkle tree on device and keep the nodes resident +/// (returned as a [`crate::lde::GpuMerkleTree`] with its root), so R4 +/// composition openings gather paths on device instead of copying the whole +/// tree to host. `leaves_len = lde_size / 2` (row pair leaves). 
+pub fn build_comp_poly_tree_from_evals_ext3_keep( + parts_interleaved: &[&[u64]], +) -> Result { + let (nodes_dev, num_leaves, stream) = build_comp_poly_tree_nodes_dev(parts_interleaved)?; + let mut root = [0u8; 32]; + stream.memcpy_dtoh(&nodes_dev.slice(0..32), &mut root)?; stream.synchronize()?; - drop(staging); - Ok(out) + Ok(crate::lde::GpuMerkleTree { + nodes: Arc::new(nodes_dev), + leaves_len: num_leaves, + root, + }) } -/// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval -/// vector. Each leaf hashes two consecutive ext3 values. `num_leaves = -/// evals.len() / 6` (since each ext3 is 3 u64s). -/// -/// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. +/// Test-only parity harness: build a FRI layer Merkle tree on device from an +/// interleaved ext3 eval vector and return the full host node buffer so tests +/// can compare it byte for byte against the CPU. Production folds and commits +/// via [`crate::fri::FriCommitState::fold_and_commit_layer`]. Each leaf hashes two +/// consecutive ext3 values; `num_leaves = evals.len() / 6`. Returns the +/// `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { assert!( evals.len().is_multiple_of(6), diff --git a/crypto/math-cuda/tests/barycentric_strided.rs b/crypto/math-cuda/tests/barycentric_strided.rs index 653ef4e38..377a2b531 100644 --- a/crypto/math-cuda/tests/barycentric_strided.rs +++ b/crypto/math-cuda/tests/barycentric_strided.rs @@ -49,6 +49,7 @@ fn run_base(log_trace: u32, blowup: usize, num_cols: usize, seed: u64) { buf: Arc::new(lde_dev), m: num_cols, lde_size, + tree: None, }; // Pre-strided buffer for non-strided reference: trace-size picks of each col. @@ -105,6 +106,7 @@ fn run_ext3(log_trace: u32, blowup: usize, num_cols: usize, seed: u64) { buf: Arc::new(lde_dev), m: num_cols, lde_size, + tree: None, }; // Pre-strided buffer for non-strided reference. 
diff --git a/crypto/math-cuda/tests/deep.rs b/crypto/math-cuda/tests/deep.rs index 8499cd04a..6ab63be10 100644 --- a/crypto/math-cuda/tests/deep.rs +++ b/crypto/math-cuda/tests/deep.rs @@ -177,12 +177,14 @@ fn run_parity( buf: Arc::new(main_dev), m: num_main, lde_size, + tree: None, }; let aux_handle = if num_aux > 0 { Some(GpuLdeExt3 { buf: Arc::new(aux_dev), m: num_aux, lde_size, + tree: None, }) } else { drop(aux_dev); diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs index 61a861f32..087ccde14 100644 --- a/crypto/math-cuda/tests/keccak_leaves.rs +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -217,8 +217,13 @@ fn keccak_comp_poly_leaves_matches_cpu() { let parts_slices: Vec<&[u64]> = parts_interleaved.iter().map(|v| v.as_slice()).collect(); - let nodes = - math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&parts_slices).unwrap(); + // Exercise the production keep path, then read the resident nodes + // back to host to check the leaf bytes. + let tree = math_cuda::merkle::build_comp_poly_tree_from_evals_ext3_keep(&parts_slices) + .unwrap(); + let be = math_cuda::device::backend().unwrap(); + let stream = be.next_stream(); + let nodes: Vec = stream.clone_dtoh(&*tree.nodes).unwrap(); let num_leaves = lde_size / 2; let leaves_offset = (num_leaves - 1) * 32; for i in 0..num_leaves { diff --git a/crypto/math-cuda/tests/merkle_gather.rs b/crypto/math-cuda/tests/merkle_gather.rs new file mode 100644 index 000000000..36e05a719 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_gather.rs @@ -0,0 +1,84 @@ +//! Parity: GPU `gather_merkle_paths_dev` must produce, for each leaf position, +//! the exact `merkle_path` the CPU `MerkleTree::get_proof_by_pos` returns: the +//! same sibling order from leaf to root, byte for byte. This is the gate for +//! gathering R4 query openings on device instead of copying the whole tree. 
+ +use crypto::merkle_tree::backends::field_element_vector::FieldElementVectorBackend; +use crypto::merkle_tree::merkle::MerkleTree; +use math::field::goldilocks::GoldilocksField; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::Keccak256; + +type CpuTree = MerkleTree>; + +fn run_gather_parity(log_n: u32, seed: u64) { + let leaves_len = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let leaves: Vec<[u8; 32]> = (0..leaves_len) + .map(|_| { + let mut arr = [0u8; 32]; + rng.fill(&mut arr[..]); + arr + }) + .collect(); + + let mut flat = Vec::with_capacity(leaves_len * 32); + for l in &leaves { + flat.extend_from_slice(l); + } + + // Build the tree on device, then upload its nodes back as the resident + // buffer the gather reads (build_merkle_tree_on_device returns host bytes). + let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); + + // CPU reference tree over the same backend as the prover. + let cpu_tree = CpuTree::build_from_hashed_leaves(leaves).unwrap(); + + // Query a spread of positions: first, last, and random interior ones. 
+ let mut positions: Vec = vec![0, (leaves_len - 1) as u32]; + let mut r = ChaCha8Rng::seed_from_u64(seed ^ 0xabcd); + for _ in 0..16usize.min(leaves_len) { + positions.push(r.gen_range(0..leaves_len) as u32); + } + + let be = math_cuda::device::backend().unwrap(); + let stream = be.next_stream(); + let nodes_dev = stream.clone_htod(&gpu_nodes_bytes).unwrap(); + stream.synchronize().unwrap(); + + let depth = log_n as usize; + let paths = + math_cuda::merkle::gather_merkle_paths_dev(&nodes_dev, leaves_len, &positions, &stream) + .unwrap(); + assert_eq!(paths.len(), positions.len() * depth * 32); + + for (q, &pos) in positions.iter().enumerate() { + let cpu_proof = cpu_tree.get_proof_by_pos(pos as usize).unwrap(); + assert_eq!( + cpu_proof.merkle_path.len(), + depth, + "depth mismatch at log_n={log_n} pos={pos}" + ); + for (level, cpu_node) in cpu_proof.merkle_path.iter().enumerate() { + let g = &paths[(q * depth + level) * 32..(q * depth + level + 1) * 32]; + assert_eq!( + g, + &cpu_node[..], + "path node mismatch: log_n={log_n} pos={pos} level={level}" + ); + } + } +} + +#[test] +fn merkle_gather_small() { + for log_n in 1u32..=6 { + run_gather_parity(log_n, 200 + log_n as u64); + } +} + +#[test] +fn merkle_gather_large() { + run_gather_parity(18, 7777); +} diff --git a/crypto/math-cuda/tests/merkle_root_parity.rs b/crypto/math-cuda/tests/merkle_root_parity.rs index ee59d323b..fcc9d226e 100644 --- a/crypto/math-cuda/tests/merkle_root_parity.rs +++ b/crypto/math-cuda/tests/merkle_root_parity.rs @@ -299,17 +299,15 @@ fn new_row_major_pipeline_base_root_matches_cpu() { let fwd_tw = TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); - let (nodes, _handle, _lde) = - math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( - &row_major, - n, - num_cols, - blowup, - &weights_u64, - ) - .expect("new row-major GPU pipeline"); - let mut gpu_root = [0u8; 32]; - gpu_root.copy_from_slice(&nodes[0..32]); + let (handle, _lde) = 
math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( + &row_major, + n, + num_cols, + blowup, + &weights_u64, + ) + .expect("new row-major GPU pipeline"); + let gpu_root = handle.tree.as_ref().expect("resident merkle tree").root; let cpu_root = cpu_row_major_merkle_root( &(0..num_cols) @@ -363,7 +361,7 @@ fn new_row_major_pipeline_ext3_root_matches_cpu() { let fwd_tw = TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); - let (nodes, _handle, _lde) = + let (handle, _lde) = math_cuda::lde::coset_lde_ext3_row_major_with_merkle_tree_keep( &row_major, n, @@ -372,8 +370,7 @@ fn new_row_major_pipeline_ext3_root_matches_cpu() { &weights_u64, ) .expect("new ext3 row-major GPU pipeline"); - let mut gpu_root = [0u8; 32]; - gpu_root.copy_from_slice(&nodes[0..32]); + let gpu_root = handle.tree.as_ref().expect("resident merkle tree").root; let cpu_root = cpu_ext3_row_major_merkle_root(&columns, blowup, &weights_fp, &inv_tw, &fwd_tw); diff --git a/crypto/stark/src/fri/fri_commitment.rs b/crypto/stark/src/fri/fri_commitment.rs index 831471761..58c9eed77 100644 --- a/crypto/stark/src/fri/fri_commitment.rs +++ b/crypto/stark/src/fri/fri_commitment.rs @@ -13,6 +13,11 @@ where { pub evaluation: Vec>, pub merkle_tree: MerkleTree, + /// The layer's Merkle tree kept resident on device (GPU FRI commit path), + /// so R4 query openings gather authentication paths on device. When set, + /// `merkle_tree` is a root only placeholder. `None` on the CPU path. 
+ #[cfg(feature = "cuda")] + pub gpu_tree: Option, } impl FriLayer @@ -25,6 +30,8 @@ where Self { evaluation: evaluation.to_vec(), merkle_tree, + #[cfg(feature = "cuda")] + gpu_tree: None, } } } diff --git a/crypto/stark/src/fri/mod.rs b/crypto/stark/src/fri/mod.rs index 60ad2a398..181c27380 100644 --- a/crypto/stark/src/fri/mod.rs +++ b/crypto/stark/src/fri/mod.rs @@ -117,6 +117,14 @@ pub fn query_phase( where FieldElement: AsBytes + Sync + Send, { + // GPU fast path: gather every layer's authentication paths on device (the + // layer trees stay resident from the GPU commit). Falls back to the host + // walk below if any layer lacks a device tree. + #[cfg(feature = "cuda")] + if let Some(decommits) = crate::gpu_lde::try_fri_query_phase_gpu::(fri_layers, iotas) { + return decommits; + } + if !fri_layers.is_empty() { let num_layers = fri_layers.len(); iotas diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 920bf937e..3f1d81846 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -16,6 +16,7 @@ use math_cuda::{CudaSlice, CudaStream}; use crypto::fiat_shamir::is_transcript::IsStarkTranscript; use crypto::merkle_tree::merkle::MerkleTree; +use crypto::merkle_tree::proof::Proof; use crypto::merkle_tree::traits::IsMerkleTreeBackend; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; @@ -23,9 +24,10 @@ use math::field::goldilocks::GoldilocksField; use math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; use math::traits::AsBytes; -use crate::config::FriLayerMerkleTreeBackend; +use crate::config::{Commitment, FriLayerMerkleTreeBackend}; use crate::domain::Domain; use crate::fri::fri_commitment::FriLayer; +use crate::fri::fri_decommit::FriDecommitment; use crate::fri::fri_functions::compute_coset_twiddles_inv; use crate::trace::LDETraceTable; @@ -430,7 +432,9 @@ pub fn gpu_leaf_hash_calls() -> u64 { } /// Row-major GPU path: single H2D → row-major NTT → 
row-major Keccak → -/// Merkle → single D2H. No column extraction or CPU-side transpose. +/// Merkle → single D2H. Keeps the Merkle tree resident on device (in the +/// handle's `.tree`); the returned host `MerkleTree` is root only, so query +/// openings gather paths from the device tree via [`gather_proofs_dev`]. pub(crate) fn try_expand_leaf_and_tree_row_major_keep( row_major: &[FieldElement], n: usize, @@ -468,7 +472,8 @@ where GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); - let (nodes_bytes, handle, lde_u64) = math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( + // The keep path keeps the Merkle tree resident on device (in `handle.tree`). + let (handle, lde_u64) = math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( raw, n, m, @@ -487,11 +492,10 @@ where ) }; - let nodes: Vec<[u8; 32]> = nodes_bytes - .chunks_exact(32) - .map(|c| c.try_into().expect("32-byte chunk")) - .collect(); - let tree = MerkleTree::::from_precomputed_nodes(nodes)?; + // Root-only host tree: the device tree (`handle.tree`) holds the nodes and + // serves openings; only the commitment root lives on host. + let root = handle.tree.as_ref()?.root; + let tree = MerkleTree::::from_root(root); Some((tree, handle, lde_out)) } @@ -537,15 +541,15 @@ where GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); - let (nodes_bytes, handle, lde_u64) = - math_cuda::lde::coset_lde_ext3_row_major_with_merkle_tree_keep( - raw, - n, - m, - blowup_factor, - &weights_u64, - ) - .ok()?; + // The keep path keeps the Merkle tree resident on device (in `handle.tree`). + let (handle, lde_u64) = math_cuda::lde::coset_lde_ext3_row_major_with_merkle_tree_keep( + raw, + n, + m, + blowup_factor, + &weights_u64, + ) + .ok()?; // Transmute Vec → Vec> (zero-copy, E == Fp3 = [u64;3]). 
let lde_out: Vec> = unsafe { @@ -561,11 +565,10 @@ where ) }; - let nodes: Vec<[u8; 32]> = nodes_bytes - .chunks_exact(32) - .map(|c| c.try_into().expect("32-byte chunk")) - .collect(); - let tree = MerkleTree::::from_precomputed_nodes(nodes)?; + // Root-only host tree: the device tree (`handle.tree`) holds the nodes and + // serves openings; only the commitment root lives on host. + let root = handle.tree.as_ref()?.root; + let tree = MerkleTree::::from_root(root); Some((tree, handle, lde_out)) } @@ -744,7 +747,7 @@ where /// recomputes on CPU. pub(crate) fn try_build_comp_poly_tree_gpu( lde_parts: &[Vec>], -) -> Option> +) -> Option<(MerkleTree, math_cuda::lde::GpuMerkleTree)> where E: IsField + 'static, B: IsMerkleTreeBackend, @@ -777,29 +780,17 @@ where }) .collect(); - let nodes_bytes = match math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&raw_parts) { - Ok(v) => v, + // Keep the composition tree resident on device, so the whole tree copy to + // host is eliminated. R4 composition openings gather paths from the device + // tree (`gather_proofs_dev`); the returned host tree is root only. + let dev_tree = match math_cuda::merkle::build_comp_poly_tree_from_evals_ext3_keep(&raw_parts) { + Ok(t) => t, Err(_) => return None, }; - - // lde_size is an even power of two >= 2, so 2*num_leaves == lde_size and - // tight_total_nodes = lde_size - 1 >= 1. No overflow or underflow possible. - let tight_total_nodes = lde_size - 1; - let expected_byte_len = tight_total_nodes - .checked_mul(32) - .expect("comp-poly node byte length overflow"); - debug_assert_eq!(nodes_bytes.len(), expected_byte_len); - - let nodes: Vec<[u8; 32]> = nodes_bytes - .chunks_exact(32) - .map(|c| { - c.try_into() - .expect("chunks_exact(32) yields exactly 32 bytes") - }) - .collect(); + debug_assert_eq!(dev_tree.leaves_len, lde_size / 2); GPU_COMP_POLY_TREE_CALLS.fetch_add(1, Ordering::Relaxed); - // Falls back to CPU on `None`, matching the R1 paths (lines 496, 557). 
- MerkleTree::::from_precomputed_nodes(nodes) + let host = MerkleTree::::from_root(dev_tree.root); + Some((host, dev_tree)) } /// R3 GPU dispatch: batched strided barycentric OOD evaluation over the main @@ -1424,18 +1415,67 @@ pub(crate) fn try_inv_denoms_dev_with_stream( coset_base: &[FieldElement], z_scalars: &[FieldElement], sign: math_cuda::inverse::DenomSign, + bound_stream: Option>, ) -> Option<(CudaSlice, Arc)> where F: IsField + 'static, E: IsField + 'static, { - let be = math_cuda::device::backend().ok()?; - let stream = be.next_stream(); + // Use the caller's per-table session stream when provided, so this table's + // R3/R4 device chain serialises on one queue; otherwise grab a pool stream. + let stream = match bound_stream { + Some(s) => s, + None => math_cuda::device::backend().ok()?.next_stream(), + }; let handle = try_compute_and_invert_inv_denoms_dev::(coset_base, z_scalars, sign, &stream)?; Some((handle, stream)) } +/// Gather Merkle authentication paths on device for `positions` (leaf indices), +/// returning one [`Proof`] per position in the same order. Byte-identical to +/// the host `MerkleTree::get_proof_by_pos` (guarded by the `merkle_gather` +/// parity test), so R4 query openings can source proofs from the resident +/// device tree instead of the host tree. Returns `None` on any cudarc error +/// (the caller then falls back to the host tree). +pub(crate) fn gather_proofs_dev( + tree: &math_cuda::lde::GpuMerkleTree, + positions: &[usize], + stream: &Arc, +) -> Option>> { + if positions.is_empty() { + return Some(Vec::new()); + } + // Positions index an LDE that `assert_u32_domain` keeps within u32; guard the + // cast so any future relaxation fails loudly instead of wrapping silently. 
+ debug_assert!( + positions.iter().all(|&p| p <= u32::MAX as usize), + "gather_proofs_dev: position exceeds u32 range" + ); + let positions_u32: Vec = positions.iter().map(|&p| p as u32).collect(); + let bytes = math_cuda::merkle::gather_merkle_paths_dev( + &tree.nodes, + tree.leaves_len, + &positions_u32, + stream, + ) + .ok()?; + let depth = tree.leaves_len.trailing_zeros() as usize; + debug_assert_eq!(bytes.len(), positions.len() * depth * 32); + let mut proofs = Vec::with_capacity(positions.len()); + for q in 0..positions.len() { + let mut merkle_path = Vec::with_capacity(depth); + for level in 0..depth { + let off = (q * depth + level) * 32; + let mut node: Commitment = [0u8; 32]; + node.copy_from_slice(&bytes[off..off + 32]); + merkle_path.push(node); + } + proofs.push(Proof { merkle_path }); + } + Some(proofs) +} + /// R3 OOD device-side context: bundles the inverted denominators, the /// coset_points upload (used by every barycentric kernel for this batch), /// and the stream so producer + consumers serialize naturally. Hoisting @@ -1459,6 +1499,7 @@ pub(crate) struct R3DevContext { pub(crate) fn try_prep_r3_dev_context( coset_base: &[FieldElement], z_scalars: &[FieldElement], + bound_stream: Option>, ) -> Option where F: IsField + 'static, @@ -1480,8 +1521,12 @@ where return None; } - let be = math_cuda::device::backend().ok()?; - let stream = be.next_stream(); + // Per-table session stream when provided (shares the queue with R4 DEEP for + // this table); otherwise a pool stream. + let stream = match bound_stream { + Some(s) => s, + None => math_cuda::device::backend().ok()?.next_stream(), + }; // SAFETY: F == Goldilocks per TypeId check; FieldElement is // #[repr(transparent)] over u64. 
@@ -1590,7 +1635,7 @@ where let zeta_ptr = &zeta as *const FieldElement as *const u64; let zeta_raw: [u64; 3] = unsafe { [*zeta_ptr, *zeta_ptr.add(1), *zeta_ptr.add(2)] }; - let (root, layer_evals_u64, nodes_bytes) = match state.fold_and_commit_layer(zeta_raw) { + let (layer_evals_u64, dev_tree) = match state.fold_and_commit_layer(zeta_raw) { Ok(v) => v, Err(_) => { *transcript = transcript_snapshot.clone(); @@ -1598,23 +1643,18 @@ where } }; - // Build the FriLayer: ext3 evals + Merkle tree from precomputed nodes. + // Build the FriLayer: ext3 evals and a root only host tree. The layer + // tree stays resident on device in `gpu_tree`; query openings gather + // paths from it via `gather_proofs_dev`. let evaluation = u64_to_ext3_vec::(&layer_evals_u64); - - debug_assert!(nodes_bytes.len().is_multiple_of(32)); - let nodes: Vec<[u8; 32]> = nodes_bytes - .chunks_exact(32) - .map(|c| c.try_into().expect("chunks_exact(32) yields 32 bytes")) - .collect(); - let merkle_tree = MerkleTree::>::from_precomputed_nodes(nodes) - .expect("FRI commit: precomputed nodes form a valid tree"); - - fri_layer_list.push(FriLayer::new(&evaluation, merkle_tree)); + let root = dev_tree.root; + let merkle_tree = MerkleTree::>::from_root(root); + let mut layer = FriLayer::new(&evaluation, merkle_tree); + layer.gpu_tree = Some(dev_tree); + fri_layer_list.push(layer); // >>>> Send commitment: [p_k] - let mut root_arr = [0u8; 32]; - root_arr.copy_from_slice(&root); - transcript.append_bytes(&root_arr); + transcript.append_bytes(&root); } // <<<< Receive challenge zeta_{n-1} @@ -1641,3 +1681,79 @@ where GPU_FRI_CALLS.fetch_add(1, Ordering::Relaxed); Some((last_value, fri_layer_list)) } + +/// GPU FRI query phase: gather each layer's paths on device instead of walking +/// host trees. For layer `l` and query `iota` the opened position is +/// `(iota >> l) >> 1`, matching [`crate::fri::query_phase`]. Paths for all +/// queries are gathered in one batched call per layer. 
The layer evaluations +/// (`evaluation[index ^ 1]`) are read from the host Vecs as before. +/// +/// Returns None when there are no layers or the layers are host trees (CPU +/// commit), so the caller falls back to the host walk. +pub(crate) fn try_fri_query_phase_gpu( + fri_layers: &[FriLayer>], + iotas: &[usize], +) -> Option>> +where + E: IsField, + FieldElement: AsBytes + Sync + Send, +{ + if fri_layers.is_empty() { + return None; + } + // The GPU FRI commit sets `gpu_tree` on every layer as a group; the CPU + // commit sets none. Host trees fall back to the host walk. When the layers + // are device resident the host trees are root only, so the gather below must + // succeed (a failure is a hard abort, not a silent walk). The residency is + // all or nothing; assert it so a future partial-build can never route a + // root-only layer through the host walk and ship empty proofs. + let first_resident = fri_layers[0].gpu_tree.is_some(); + debug_assert!( + fri_layers + .iter() + .all(|l| l.gpu_tree.is_some() == first_resident), + "FRI layer residency must be all or nothing" + ); + if !first_resident { + return None; + } + let stream = math_cuda::device::backend() + .expect("cuda backend for device-resident FRI query") + .next_stream(); + let num_layers = fri_layers.len(); + + // Batched gather: one call per layer over all queries. + let mut per_layer_proofs: Vec>> = Vec::with_capacity(num_layers); + for (l, layer) in fri_layers.iter().enumerate() { + let tree = layer + .gpu_tree + .as_ref() + .expect("FRI layers are device-resident as a group"); + let positions: Vec = iotas.iter().map(|&iota| (iota >> l) >> 1).collect(); + per_layer_proofs.push( + gather_proofs_dev(tree, &positions, &stream) + .expect("device FRI-layer gather failed; resident tree has no host fallback"), + ); + } + + // Reassemble per-query decommitments, matching the host walk's order. 
+ let decommits = iotas + .iter() + .enumerate() + .map(|(q, &iota)| { + let mut layers_evaluations_sym = Vec::with_capacity(num_layers); + let mut layers_auth_paths = Vec::with_capacity(num_layers); + let mut index = iota; + for (l, layer) in fri_layers.iter().enumerate() { + layers_evaluations_sym.push(layer.evaluation[index ^ 1].clone()); + layers_auth_paths.push(per_layer_proofs[l][q].clone()); + index >>= 1; + } + FriDecommitment { + layers_auth_paths, + layers_evaluations_sym, + } + }) + .collect(); + Some(decommits) +} diff --git a/crypto/stark/src/instruments.rs b/crypto/stark/src/instruments.rs index aa5cc5436..f263558aa 100644 --- a/crypto/stark/src/instruments.rs +++ b/crypto/stark/src/instruments.rs @@ -1,7 +1,153 @@ use std::cell::RefCell; +use std::sync::Mutex; use std::sync::OnceLock; use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Duration; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +// Wall clock span timeline: the trustworthy per step latency breakdown. +// +// Spans open and close on the main thread at phase boundaries. They do not +// overlap and sum to their parent, so the tree is a true latency breakdown +// (unlike the accum_* thread local sub timers below, which sum per worker CPU +// time across rayon threads and can exceed 100%). A parallel region is one span +// around the blocking call; its internal split is reported separately as CPU +// time, never mixed into the wall tree. +// +// let _s = instruments::span("trace_build"); // RAII, stops on drop +// +// Instant::now() is about 20 ns, fine at phase granularity, not in per op loops. + +#[derive(Clone, Debug)] +pub struct SpanRecord { + pub label: &'static str, + pub depth: u16, + pub wall: Duration, + /// Open-order, so the tree reconstructs in start-order (records push on close). + pub order: u32, + /// Wall clock epoch (ns) when the span opened, for aligning with external + /// samplers (e.g. 
nvidia-smi GPU util) to attribute device busy time per step. + pub start_ns: u128, +} + +static TIMELINE: Mutex> = Mutex::new(Vec::new()); +static SPAN_ORDER: AtomicU64 = AtomicU64::new(0); + +thread_local! { + static SPAN_DEPTH: std::cell::Cell = const { std::cell::Cell::new(0) }; +} + +#[must_use] +pub struct SpanGuard { + label: &'static str, + depth: u16, + order: u32, + start: Instant, + start_ns: u128, +} + +/// Open a wall-clock span; records elapsed time when the guard drops. +pub fn span(label: &'static str) -> SpanGuard { + let depth = SPAN_DEPTH.with(|d| { + let v = d.get(); + d.set(v + 1); + v + }); + let order = SPAN_ORDER.fetch_add(1, Ordering::Relaxed) as u32; + let start_ns = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + SpanGuard { + label, + depth, + order, + start: Instant::now(), + start_ns, + } +} + +impl Drop for SpanGuard { + fn drop(&mut self) { + let wall = self.start.elapsed(); + SPAN_DEPTH.with(|d| d.set(d.get().saturating_sub(1))); + if let Ok(mut t) = TIMELINE.lock() { + t.push(SpanRecord { + label: self.label, + depth: self.depth, + wall, + order: self.order, + start_ns: self.start_ns, + }); + } + } +} + +/// Clear recorded spans. Call at the start of a measured prove. +pub fn reset_timeline() { + SPAN_ORDER.store(0, Ordering::Relaxed); + SPAN_DEPTH.with(|d| d.set(0)); + if let Ok(mut t) = TIMELINE.lock() { + t.clear(); + } +} + +/// Drain recorded spans, sorted in start-order (ready for the tree). +pub fn take_timeline() -> Vec { + let mut spans = TIMELINE + .lock() + .map(|mut t| std::mem::take(&mut *t)) + .unwrap_or_default(); + spans.sort_by_key(|s| s.order); + spans +} + +/// Indented wall-clock tree with % of the root span. 
+pub fn format_timeline(spans: &[SpanRecord]) -> String { + use std::fmt::Write; + if spans.is_empty() { + return String::new(); + } + let total_s = spans + .first() + .map(|s| s.wall.as_secs_f64()) + .unwrap_or(1e-9) + .max(1e-9); + let mut out = String::from("=== TIMELINE (wall-clock) ===\n"); + for s in spans { + let indent = " ".repeat(s.depth as usize); + let pct = 100.0 * s.wall.as_secs_f64() / total_s; + let _ = writeln!( + out, + "{:<42} {:>10.3?} {:>6.1}%", + format!("{indent}{}", s.label), + s.wall, + pct + ); + } + out +} + +/// JSON array of `{label, depth, wall_ns, order}` for diffing / plotting. +pub fn timeline_json(spans: &[SpanRecord]) -> String { + let mut out = String::from("["); + for (i, s) in spans.iter().enumerate() { + if i > 0 { + out.push(','); + } + // Escape the label so a quote or backslash cannot break the JSON. + let label = s.label.replace('\\', "\\\\").replace('"', "\\\""); + out.push_str(&format!( + "{{\"label\":\"{}\",\"depth\":{},\"wall_ns\":{},\"order\":{},\"start_ns\":{}}}", + label, + s.depth, + s.wall.as_nanos(), + s.order, + s.start_ns + )); + } + out.push(']'); + out +} static HEAP_READER: OnceLock Option> = OnceLock::new(); @@ -122,8 +268,8 @@ pub fn take_r1_sub() -> Round1SubOps { /// Reset all instrument state. Call at the start of `multi_prove` to avoid /// stale data from a previous run in the same process. /// -/// Note: thread-local stores (R2_SUB, R4_SUB, ROUND_SUB_OPS) are only cleared -/// for the calling thread. Rayon worker threads are not reset — stale data is +/// Note: thread local stores (R2_SUB, R4_SUB, ROUND_SUB_OPS) are only cleared +/// for the calling thread. Rayon worker threads are not reset, so stale data is /// possible if a previous run panicked without consuming stored values. /// In practice this is safe because store/take pairs always execute within the /// same rayon task closure. 
diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 2ce1cb855..cdf1cd1b2 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -43,6 +43,8 @@ use super::lookup::BusPublicInputs; use super::proof::stark::{DeepPolynomialOpening, MultiProof, StarkProof}; use super::trace::TraceTable; use super::traits::AIR; +#[cfg(feature = "cuda")] +use crypto::merkle_tree::proof::Proof; pub use crate::commitment::{keccak_leaves_bit_reversed, keccak_leaves_row_pair_bit_reversed}; @@ -422,6 +424,53 @@ pub fn table_parallelism() -> usize { } } +/// Heuristic peak device bytes for one table: co-resident LDE columns plus the +/// resident Merkle trees, with a scratch factor for NTT and leaf transients. A +/// deliberate over estimate for a safety ceiling, not a precise allocator. Pass +/// aux_cols == 0 when the aux LDE is not yet resident (R1 main commit). +fn estimate_table_vram_bytes(main_cols: usize, aux_cols: usize, lde_size: usize) -> u64 { + const BYTES_PER_BASE: u64 = 8; + const EXT3_BYTES: u64 = 24; + const SCRATCH_FACTOR: u64 = 2; + const RESIDENT_TREE_BYTES_PER_LDE: u64 = 256; + let lde = lde_size as u64; + let per_row = (main_cols as u64).saturating_mul(BYTES_PER_BASE) + + (aux_cols as u64).saturating_mul(EXT3_BYTES); + let lde_term = lde.saturating_mul(per_row).saturating_mul(SCRATCH_FACTOR); + let tree_term = lde.saturating_mul(RESIDENT_TREE_BYTES_PER_LDE); + lde_term.saturating_add(tree_term) +} + +/// Plan contiguous table chunks for parallel proving. A chunk grows until it +/// hits `k` tables or its summed VRAM estimate would exceed `budget`; a single +/// table larger than `budget` runs solo. With `budget == u64::MAX` (non-cuda, +/// or VRAM not binding) chunks fall back to fixed size `k`, identical to the +/// old `step_by(k)`, so scheduling and the proof are unchanged. Returns +/// `(start, end)` half open ranges covering `0..estimates.len()` in order. 
+fn plan_table_chunks(estimates: &[u64], k: usize, budget: u64) -> Vec<(usize, usize)> { + let n = estimates.len(); + let k = k.max(1); + let budget = budget as u128; + let mut chunks = Vec::new(); + let mut start = 0; + while start < n { + let mut end = start; + let mut acc: u128 = 0; + while end < n { + let next = estimates[end] as u128; + // Always admit at least one table per chunk (oversized → solo). + if end > start && (end - start >= k || acc + next > budget) { + break; + } + acc += next; + end += 1; + } + chunks.push((start, end)); + start = end; + } + chunks +} + /// A container for the results of the second round of the STARK Prove protocol. pub(crate) struct Round2 where @@ -434,13 +483,12 @@ where pub(crate) composition_poly_merkle_tree: BatchedMerkleTree, /// The commitment to the composition polynomial parts. pub(crate) composition_poly_root: Commitment, - /// Device-resident de-interleaved LDE handle from the R2 fused GPU path - /// (`try_evaluate_parts_on_lde_gpu_keep`). When present, R4 DEEP skips - /// the `num_parts * 3 * lde_size * 8` byte H2D and reads parts on - /// device. `None` when the GPU R2 path didn't run (number_of_parts <= 2, - /// below threshold, or any CPU fallback). + /// The composition Merkle tree kept resident on device (when the R2 GPU tree + /// path ran), so R4 openings gather paths on device instead of walking a host + /// tree. When set, `composition_poly_merkle_tree` is a root only placeholder. + /// `None` on the CPU path. #[cfg(feature = "cuda")] - pub(crate) gpu_composition_parts: Option, + pub(crate) gpu_composition_tree: Option, } /// A container for the results of the third round of the STARK Prove protocol. 
@@ -1097,7 +1145,7 @@ pub trait IsStarkProver< pub_inputs: &PI, domain: &Domain, twiddles: &LdeTwiddles, - round_1_result: &Round1, + round_1_result: &mut Round1, transition_coefficients: &[FieldElement], boundary_coefficients: &[FieldElement], ) -> Result, ProvingError> @@ -1197,37 +1245,55 @@ pub trait IsStarkProver< let t_sub = Instant::now(); // GPU fast path for the comp-poly Merkle commit: row-pair Keccak // leaves + device-side inner tree, both wrapping the host eval Vecs. + // GPU path keeps the composition tree resident on device (no whole tree + // copy) and returns a root only host tree. The device tree is threaded + // to R4 in `Round2.gpu_composition_tree`. #[cfg(feature = "cuda")] - let gpu_tree = crate::gpu_lde::try_build_comp_poly_tree_gpu::< - FieldExtension, - BatchedMerkleTreeBackend, - >(&lde_composition_poly_parts_evaluations); + let (composition_poly_merkle_tree, composition_poly_root, gpu_composition_tree) = + match crate::gpu_lde::try_build_comp_poly_tree_gpu::< + FieldExtension, + BatchedMerkleTreeBackend, + >(&lde_composition_poly_parts_evaluations) + { + Some((host_tree, dev_tree)) => { + let root = host_tree.root; + (host_tree, root, Some(dev_tree)) + } + None => { + let (tree, root) = crate::commitment::commit_bit_reversed( + &lde_composition_poly_parts_evaluations, + crate::commitment::ROWS_PER_LEAF, + ) + .ok_or(ProvingError::EmptyCommitment)?; + (tree, root, None) + } + }; #[cfg(not(feature = "cuda"))] - let gpu_tree: Option> = None; - - let (composition_poly_merkle_tree, composition_poly_root) = match gpu_tree { - Some(tree) => { - let root = tree.root; - (tree, root) - } - None => crate::commitment::commit_bit_reversed( + let (composition_poly_merkle_tree, composition_poly_root) = + crate::commitment::commit_bit_reversed( &lde_composition_poly_parts_evaluations, crate::commitment::ROWS_PER_LEAF, ) - .ok_or(ProvingError::EmptyCommitment)?, - }; + .ok_or(ProvingError::EmptyCommitment)?; #[cfg(feature = "instruments")] let merkle_dur 
= t_sub.elapsed(); #[cfg(feature = "instruments")] crate::instruments::store_r2_sub(constraints_dur, fft_dur, merkle_dur); + // Fold the R2 device composition parts handle into the session (resident + // R2 to R4). The host evaluations stay in `Round2` for R4 openings. + #[cfg(feature = "cuda")] + if let Some(handle) = gpu_composition_parts { + round_1_result.lde_trace.set_gpu_composition_parts(handle); + } + Ok(Round2 { lde_composition_poly_evaluations: lde_composition_poly_parts_evaluations, composition_poly_merkle_tree, composition_poly_root, #[cfg(feature = "cuda")] - gpu_composition_parts, + gpu_composition_tree, }) } @@ -1487,11 +1553,12 @@ pub trait IsStarkProver< &domain.lde_roots_of_unity_coset, &z_scalars, math_cuda::inverse::DenomSign::XMinusZ, + lde_trace.bound_stream(), ) && let Some(deep_evals) = crate::gpu_lde::try_deep_composition_gpu::( lde_trace, - round_2_result.gpu_composition_parts.as_ref(), + lde_trace.gpu_composition_parts(), &round_2_result.lde_composition_poly_evaluations, h_ood, &trace_ood_columns, @@ -1527,7 +1594,7 @@ pub trait IsStarkProver< if let Some(deep_evals) = crate::gpu_lde::try_deep_composition_gpu::( lde_trace, - round_2_result.gpu_composition_parts.as_ref(), + lde_trace.gpu_composition_parts(), &round_2_result.lde_composition_poly_evaluations, h_ood, &trace_ood_columns, @@ -1648,6 +1715,45 @@ pub trait IsStarkProver< } } + /// Like [`Self::open_composition_poly`] but uses a Merkle proof already + /// gathered from the resident device composition tree + /// ([`crate::gpu_lde::gather_proofs_dev`]) instead of walking a host tree. + /// Row-pair leaf: one proof at position `index` authenticates both rows. 
+ #[cfg(feature = "cuda")] + fn open_composition_poly_with_proof( + proof: Proof, + lde_composition_poly_evaluations: &[Vec>], + index: usize, + ) -> PolynomialOpenings + where + FieldElement: AsBytes + Sync + Send, + FieldElement: AsBytes + Sync + Send, + { + let lde_composition_poly_parts_evaluation: Vec<_> = lde_composition_poly_evaluations + .iter() + .flat_map(|part| { + vec![ + part[reverse_index(index * 2, part.len() as u64)].clone(), + part[reverse_index(index * 2 + 1, part.len() as u64)].clone(), + ] + }) + .collect(); + + PolynomialOpenings { + proof, + evaluations: lde_composition_poly_parts_evaluation + .clone() + .into_iter() + .step_by(2) + .collect(), + evaluations_sym: lde_composition_poly_parts_evaluation + .into_iter() + .skip(1) + .step_by(2) + .collect(), + } + } + /// Computes values and validity proofs of the evaluations of trace polynomials at /// the FRI query challenge `challenge` and its symmetric counterpart. The caller /// supplies a `gather` closure that pulls the row data from the column-major LDE @@ -1676,6 +1782,31 @@ pub trait IsStarkProver< } } + /// Like [`Self::open_polys_with`], but uses a Merkle proof already gathered + /// from the resident device tree (see [`crate::gpu_lde::gather_proofs_dev`]) + /// instead of walking a host tree. Row-pair leaf: one proof at position + /// `challenge` authenticates both the queried row and its symmetric + /// counterpart. Evaluations still come from the host LDE columns via `gather`. 
+ #[cfg(feature = "cuda")] + fn open_polys_with_proofs( + domain: &Domain, + proof: Proof, + challenge: usize, + gather: G, + ) -> PolynomialOpenings + where + C: IsField, + FieldElement: AsBytes + Sync + Send, + G: Fn(usize) -> Vec>, + { + let domain_size = domain.lde_roots_of_unity_coset.len() as u64; + PolynomialOpenings { + proof, + evaluations: gather(reverse_index(challenge * 2, domain_size)), + evaluations_sym: gather(reverse_index(challenge * 2 + 1, domain_size)), + } + } + /// Open the deep composition polynomial on a list of indexes and their symmetric elements. fn open_deep_composition_poly( domain: &Domain, @@ -1695,7 +1826,63 @@ pub trait IsStarkProver< let num_precomputed_cols = main_commit.num_precomputed_cols; let total_cols = lde_trace.num_main_cols(); - for index in indexes_to_open.iter() { + // R4 trace proofs from the resident device trees, gathered in one batch + // over all query positions instead of walking the host trees (byte + // identical to the host proofs, guarded by the `merkle_gather` test). + // `*_dev_proofs` is `Some` exactly when the tree is device resident (so + // the host tree is a root only placeholder). In that case the gather + // must succeed: there is no host tree to fall back to, so a gather error + // is a hard abort. When the tree is not device resident the value is + // `None` and the openings below walk the full host tree. + #[cfg(feature = "cuda")] + let main_dev_proofs: Option>> = if is_preprocessed { + None + } else { + lde_trace + .gpu_main() + .and_then(|h| h.tree.as_ref()) + .map(|tree| { + let stream = lde_trace + .bound_stream() + .expect("bound stream for device-resident main-tree opening"); + // Row-pair leaves: one proof per query at position `challenge`. + crate::gpu_lde::gather_proofs_dev(tree, indexes_to_open, &stream).expect( + "device main-tree gather failed; resident tree has no host fallback", + ) + }) + }; + + // Same for the aux trace tree, when it is device resident. 
+ #[cfg(feature = "cuda")] + let aux_dev_proofs: Option>> = round_1_result + .aux + .as_ref() + .and_then(|_aux| lde_trace.gpu_aux().and_then(|h| h.tree.as_ref())) + .map(|tree| { + let stream = lde_trace + .bound_stream() + .expect("bound stream for device-resident aux-tree opening"); + // Row-pair leaves: one proof per query at position `challenge`. + crate::gpu_lde::gather_proofs_dev(tree, indexes_to_open, &stream) + .expect("device aux-tree gather failed; resident tree has no host fallback") + }); + + // Composition tree: openings open a single position `index` (row pair + // leaf), so gather one proof per query challenge from the device tree. + #[cfg(feature = "cuda")] + let comp_dev_proofs: Option>> = + round_2_result.gpu_composition_tree.as_ref().map(|tree| { + let stream = lde_trace + .bound_stream() + .expect("bound stream for device-resident composition-tree opening"); + crate::gpu_lde::gather_proofs_dev(tree, indexes_to_open, &stream).expect( + "device composition-tree gather failed; resident tree has no host fallback", + ) + }); + + for (qi, index) in indexes_to_open.iter().enumerate() { + #[cfg(not(feature = "cuda"))] + let _ = qi; // For preprocessed tables, open the main split (multiplicities only); // for normal tables, open all main columns. 
let main_trace_opening = if is_preprocessed { @@ -1703,9 +1890,24 @@ pub trait IsStarkProver< lde_trace.gather_main_row_range(row, num_precomputed_cols, total_cols) }) } else { - Self::open_polys_with(domain, &main_commit.tree, *index, |row| { - lde_trace.gather_main_row(row) - }) + #[cfg(feature = "cuda")] + { + if let Some(proofs) = &main_dev_proofs { + Self::open_polys_with_proofs(domain, proofs[qi].clone(), *index, |row| { + lde_trace.gather_main_row(row) + }) + } else { + Self::open_polys_with(domain, &main_commit.tree, *index, |row| { + lde_trace.gather_main_row(row) + }) + } + } + #[cfg(not(feature = "cuda"))] + { + Self::open_polys_with(domain, &main_commit.tree, *index, |row| { + lde_trace.gather_main_row(row) + }) + } }; // For preprocessed tables, also open the precomputed-columns tree. @@ -1715,16 +1917,52 @@ pub trait IsStarkProver< }) }); - let composition_openings = Self::open_composition_poly( - &round_2_result.composition_poly_merkle_tree, - &round_2_result.lde_composition_poly_evaluations, - *index, - ); + let composition_openings = { + #[cfg(feature = "cuda")] + { + if let Some(proofs) = &comp_dev_proofs { + Self::open_composition_poly_with_proof( + proofs[qi].clone(), + &round_2_result.lde_composition_poly_evaluations, + *index, + ) + } else { + Self::open_composition_poly( + &round_2_result.composition_poly_merkle_tree, + &round_2_result.lde_composition_poly_evaluations, + *index, + ) + } + } + #[cfg(not(feature = "cuda"))] + { + Self::open_composition_poly( + &round_2_result.composition_poly_merkle_tree, + &round_2_result.lde_composition_poly_evaluations, + *index, + ) + } + }; let aux_trace_polys = round_1_result.aux.as_ref().map(|aux| { - Self::open_polys_with(domain, &aux.tree, *index, |row| { - lde_trace.gather_aux_row(row) - }) + #[cfg(feature = "cuda")] + { + if let Some(proofs) = &aux_dev_proofs { + Self::open_polys_with_proofs(domain, proofs[qi].clone(), *index, |row| { + lde_trace.gather_aux_row(row) + }) + } else { + 
Self::open_polys_with(domain, &aux.tree, *index, |row| { + lde_trace.gather_aux_row(row) + }) + } + } + #[cfg(not(feature = "cuda"))] + { + Self::open_polys_with(domain, &aux.tree, *index, |row| { + lde_trace.gather_aux_row(row) + }) + } }); openings.push(DeepPolynomialOpening { @@ -1792,6 +2030,8 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + #[cfg(feature = "instruments")] + let __sp = crate::instruments::span("r1_prepass"); // Deduplicate Domain + LdeTwiddles by (trace_length, blowup_factor, coset_offset). // Many tables share the same domain size (e.g., 7+ tables at 2^20). @@ -1834,6 +2074,33 @@ pub trait IsStarkProver< let k = table_parallelism().min(num_airs).max(1); + // VRAM budgeted admission. The budget caps the summed device working set + // of the tables proved concurrently so large blocks don't exhaust VRAM. + // It is an extra ceiling on top of `k` (it never raises concurrency). On + // non-cuda builds, or when the budget can't be queried, it is `u64::MAX` + // and chunking falls back to fixed size `k`. + #[cfg(feature = "cuda")] + let vram_budget = math_cuda::device::backend() + .map(|b| b.vram_budget_bytes()) + .unwrap_or(u64::MAX); + #[cfg(not(feature = "cuda"))] + let vram_budget = u64::MAX; + + // R1 main commit: only the main LDE and its Merkle scratch are resident, + // so the aux columns add nothing to this phase's working set. + let main_chunks = { + let estimates: Vec = air_trace_pairs + .iter() + .enumerate() + .map(|(idx, (_, trace, _))| { + let lde_size = + domains[idx].interpolation_domain_size * domains[idx].blowup_factor; + estimate_table_vram_bytes(trace.num_main_columns, 0, lde_size) + }) + .collect(); + plan_table_chunks(&estimates, k, vram_budget) + }; + // Spill main traces to mmap before Round 1 LDE. 
#[cfg(feature = "disk-spill")] if storage_mode == StorageMode::Disk { @@ -1845,6 +2112,8 @@ pub trait IsStarkProver< })?; } + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let prepass_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -1860,6 +2129,8 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + #[cfg(feature = "instruments")] + let __sp = crate::instruments::span("r1_main_commit"); let mut main_commits: Vec> = Vec::with_capacity(num_airs); let mut main_ldes: Vec<(Vec>, usize)> = Vec::with_capacity(num_airs); @@ -1870,8 +2141,7 @@ pub trait IsStarkProver< let mut main_gpu_handles: Vec> = Vec::with_capacity(num_airs); - for chunk_start in (0..num_airs).step_by(k) { - let chunk_end = (chunk_start + k).min(num_airs); + for &(chunk_start, chunk_end) in &main_chunks { let chunk_range = chunk_start..chunk_end; let chunk_results: Vec> = @@ -1910,6 +2180,8 @@ pub trait IsStarkProver< } } + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let main_commits_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -1945,6 +2217,8 @@ pub trait IsStarkProver< // but outer parallelism over 12 tables also helps on high-core-count machines. #[cfg(feature = "instruments")] let phase_start = Instant::now(); + #[cfg(feature = "instruments")] + let __sp = crate::instruments::span("r1_aux_build"); #[cfg(feature = "parallel")] let aux_iter = air_trace_pairs.par_iter_mut(); @@ -1973,6 +2247,8 @@ pub trait IsStarkProver< })?; } + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let aux_build_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -1984,6 +2260,8 @@ pub trait IsStarkProver< // Each table gets its own transcript fork. 
#[cfg(feature = "instruments")] let phase_start = Instant::now(); + #[cfg(feature = "instruments")] + let __sp = crate::instruments::span("r1_aux_commit"); // Pre-fork all transcripts (cheap, sequential — must match verifier ordering) let mut table_transcripts: Vec<_> = (0..num_airs) @@ -2011,8 +2289,28 @@ pub trait IsStarkProver< #[allow(clippy::type_complexity)] let mut aux_results: Vec> = Vec::with_capacity(num_airs); - for chunk_start in (0..num_airs).step_by(k) { - let chunk_end = (chunk_start + k).min(num_airs); + // R1 aux commit and rounds 2 to 4 share the peak working set: the main + // and aux LDEs are co-resident, plus the composition and Merkle + // transients (in the scratch factor). `num_aux_columns` is populated by + // the aux build above, so this estimate is accurate for both phases. + let peak_chunks = { + let estimates: Vec = air_trace_pairs + .iter() + .enumerate() + .map(|(idx, (_, trace, _))| { + let lde_size = + domains[idx].interpolation_domain_size * domains[idx].blowup_factor; + estimate_table_vram_bytes( + trace.num_main_columns, + trace.num_aux_columns, + lde_size, + ) + }) + .collect(); + plan_table_chunks(&estimates, k, vram_budget) + }; + + for &(chunk_start, chunk_end) in &peak_chunks { let chunk_range = chunk_start..chunk_end; #[allow(clippy::type_complexity)] @@ -2174,6 +2472,8 @@ pub trait IsStarkProver< }); } + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let aux_commit_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -2194,6 +2494,8 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); #[cfg(feature = "instruments")] + let __sp = crate::instruments::span("rounds_2to4"); + #[cfg(feature = "instruments")] let mut table_timings: Vec<( String, usize, @@ -2203,8 +2505,7 @@ pub trait IsStarkProver< let mut proofs = Vec::with_capacity(num_airs); let mut lde_drain = cached_ldes.into_iter(); - for chunk_start in (0..num_airs).step_by(k) { - let 
chunk_end = (chunk_start + k).min(num_airs); + for &(chunk_start, chunk_end) in &peak_chunks { let chunk_size = chunk_end - chunk_start; let chunk_ldes: Vec> = @@ -2236,7 +2537,7 @@ pub trait IsStarkProver< let table_start = Instant::now(); // Build Round1 from cached LDE (consumed by value, no recomputation). - let round_1_result = + let mut round_1_result = commitment.build_round1(lde, air.step_size(), domain.blowup_factor); if let Some(ref bpi) = round_1_result.bus_public_inputs { @@ -2246,7 +2547,7 @@ pub trait IsStarkProver< let proof = Self::prove_rounds_2_to_4( *air, *pub_inputs, - &round_1_result, + &mut round_1_result, table_transcript, domain, &twiddle_caches[idx], @@ -2282,6 +2583,8 @@ pub trait IsStarkProver< } } + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] { // Store timing data for the top-level report in prove_with_options. @@ -2334,7 +2637,7 @@ pub trait IsStarkProver< fn prove_rounds_2_to_4( air: &dyn AIR, pub_inputs: &PI, - round_1_result: &Round1, + round_1_result: &mut Round1, transcript: &mut (impl IsStarkTranscript + Clone), domain: &Domain, twiddles: &LdeTwiddles, diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index 72b77947a..0782ea245 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -9,6 +9,8 @@ use math::spill_safe::SpillSafe; use rayon::prelude::{ IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSliceMut, }; +#[cfg(feature = "cuda")] +use std::sync::{Arc, OnceLock}; /// A two-dimensional representation of an execution trace of the STARK /// protocol. @@ -210,13 +212,44 @@ where pub(crate) num_rows: usize, pub(crate) lde_step_size: usize, pub(crate) blowup_factor: usize, - /// If the main trace was LDE'd on the GPU via the fused pipeline, the - /// device buffer is retained here so downstream GPU rounds can read the - /// LDE without a re-H2D. `None` on any CPU path. 
+ /// Per table GPU residency session: owns this table's device LDE buffers + /// and bound stream. Threaded R1 to R4. Empty on the CPU path. #[cfg(feature = "cuda")] - pub(crate) gpu_main: Option, - #[cfg(feature = "cuda")] - pub(crate) gpu_aux: Option, + pub(crate) gpu_session: GpuTableSession, +} + +/// Per table GPU residency session. +/// +/// Owns the device buffers for one trace table: the main and aux trace LDE +/// (resident R1 to R4), the composition parts LDE (R2 to R4), and a bound +/// stream. The R4 local inv_denoms and FRI state stay local to R4. +#[cfg(feature = "cuda")] +pub(crate) struct GpuTableSession { + /// Main trace LDE, resident from the R1 fused pipeline through R4. None + /// when the GPU LDE did not run (below threshold, preprocessed main, not + /// Goldilocks, or a GPU error). + main_lde: Option, + /// Aux trace LDE (ext3, deinterleaved on device), resident R1 to R4. + aux_lde: Option, + /// Composition parts LDE (ext3, deinterleaved on device), produced in R2 + /// and resident R2 to R4 so R4 DEEP reads them on device. None when the R2 + /// GPU path did not run. + composition_parts: Option, + /// Stream bound to this table's GPU work, acquired lazily from the backend + /// pool and cached. None is cached when the backend is unavailable. + stream: OnceLock>>, +} + +#[cfg(feature = "cuda")] +impl GpuTableSession { + fn new() -> Self { + Self { + main_lde: None, + aux_lde: None, + composition_parts: None, + stream: OnceLock::new(), + } + } } impl LDETraceTable @@ -311,9 +344,7 @@ where lde_step_size, blowup_factor, #[cfg(feature = "cuda")] - gpu_main: None, - #[cfg(feature = "cuda")] - gpu_aux: None, + gpu_session: GpuTableSession::new(), } } @@ -348,34 +379,54 @@ where lde_step_size, blowup_factor, #[cfg(feature = "cuda")] - gpu_main: None, - #[cfg(feature = "cuda")] - gpu_aux: None, + gpu_session: GpuTableSession::new(), } } - /// Attach an already-populated device LDE handle for the main columns. 
- /// Only set when the GPU fused pipeline produced the LDE. Callers that - /// ran the CPU path should leave this alone. + /// Attach the device LDE handle for the main columns, produced by the GPU + /// fused pipeline. Leave unset on the CPU path. #[cfg(feature = "cuda")] pub fn set_gpu_main(&mut self, h: math_cuda::lde::GpuLdeBase) { - self.gpu_main = Some(h); + self.gpu_session.main_lde = Some(h); } /// Attach an already-populated device LDE handle for the aux columns. #[cfg(feature = "cuda")] pub fn set_gpu_aux(&mut self, h: math_cuda::lde::GpuLdeExt3) { - self.gpu_aux = Some(h); + self.gpu_session.aux_lde = Some(h); } #[cfg(feature = "cuda")] pub fn gpu_main(&self) -> Option<&math_cuda::lde::GpuLdeBase> { - self.gpu_main.as_ref() + self.gpu_session.main_lde.as_ref() } #[cfg(feature = "cuda")] pub fn gpu_aux(&self) -> Option<&math_cuda::lde::GpuLdeExt3> { - self.gpu_aux.as_ref() + self.gpu_session.aux_lde.as_ref() + } + + /// Attach the composition parts LDE produced in R2. Read by R4 DEEP so the + /// parts are not re-uploaded. + #[cfg(feature = "cuda")] + pub fn set_gpu_composition_parts(&mut self, h: math_cuda::lde::GpuLdeExt3) { + self.gpu_session.composition_parts = Some(h); + } + + #[cfg(feature = "cuda")] + pub fn gpu_composition_parts(&self) -> Option<&math_cuda::lde::GpuLdeExt3> { + self.gpu_session.composition_parts.as_ref() + } + + /// The stream bound to this table's GPU work. Acquired lazily from the + /// backend pool on first call and cached, so all of a table's stream ops + /// share one queue. Returns None (cached) when the backend is unavailable. + #[cfg(feature = "cuda")] + pub fn bound_stream(&self) -> Option> { + self.gpu_session + .stream + .get_or_init(|| math_cuda::device::backend().ok().map(|b| b.next_stream())) + .clone() } pub fn num_main_cols(&self) -> usize { @@ -495,7 +546,11 @@ where // both via offset, with no per-eval-point or per-{main,aux} H2D. 
#[cfg(feature = "cuda")] let r3_ctx: Option = - crate::gpu_lde::try_prep_r3_dev_context::(&dc.points, &evaluation_points); + crate::gpu_lde::try_prep_r3_dev_context::( + &dc.points, + &evaluation_points, + lde_trace.bound_stream(), + ); #[allow(unused_variables)] #[cfg(not(feature = "cuda"))] let r3_ctx: Option<()> = None; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 760383003..6bbde8b84 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -781,11 +781,17 @@ pub fn prove_with_options_and_inputs( #[cfg(feature = "instruments")] let total_start = std::time::Instant::now(); #[cfg(feature = "instruments")] + stark::instruments::reset_timeline(); + #[cfg(feature = "instruments")] + let __root = stark::instruments::span("prove_total"); + #[cfg(feature = "instruments")] let heap_before = stark::instruments::heap_bytes(); // Phase 1: Execute (ELF load + run) #[cfg(feature = "instruments")] let phase_start = std::time::Instant::now(); + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("execute"); let program = Elf::load(elf_bytes).map_err(|e| Error::ElfLoad(format!("{e}")))?; let executor = Executor::new(&program, private_inputs.to_vec()) @@ -794,6 +800,8 @@ pub fn prove_with_options_and_inputs( .run() .map_err(|e| Error::Execution(format!("{e}")))?; + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let execute_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -802,6 +810,8 @@ pub fn prove_with_options_and_inputs( // Phase 2: Trace build #[cfg(feature = "instruments")] let phase_start = std::time::Instant::now(); + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("trace_build"); #[cfg(feature = "disk-spill")] let storage_mode = { @@ -823,6 +833,8 @@ pub fn prove_with_options_and_inputs( ); drop(result); + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let trace_build_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ 
-831,6 +843,8 @@ pub fn prove_with_options_and_inputs( // Phase 3: AIR construction #[cfg(feature = "instruments")] let phase_start = std::time::Instant::now(); + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("air_construction"); let table_counts = traces.table_counts(); let airs = VmAirs::new( @@ -846,6 +860,8 @@ pub fn prove_with_options_and_inputs( None, ); + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] let air_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -873,6 +889,8 @@ pub fn prove_with_options_and_inputs( ); // Phase 4: Prove (multi_prove) + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("proving"); let proof = Prover::multi_prove( airs.air_trace_pairs(&mut traces), &mut transcript, @@ -880,6 +898,8 @@ pub fn prove_with_options_and_inputs( storage_mode, ) .map_err(|e| Error::Prover(format!("{e:?}")))?; + #[cfg(feature = "instruments")] + drop(__sp); #[cfg(feature = "instruments")] { @@ -895,6 +915,17 @@ pub fn prove_with_options_and_inputs( after_air: heap_after_air, }, ); + // Accurate wall-clock span tree (the trustworthy per-step breakdown).
+ drop(__root); + let spans = stark::instruments::take_timeline(); + print!("{}", stark::instruments::format_timeline(&spans)); + if let Ok(path) = std::env::var("LAMBDA_VM_TIMELINE_JSON") { + // The JSON export is best-effort and must not fail the proof, but the log line has to be truthful: report a failed write on stderr instead of claiming success. + match std::fs::write(&path, stark::instruments::timeline_json(&spans)) { + Ok(()) => println!("[timeline] wrote {path}"), + Err(e) => eprintln!("[timeline] failed to write {path}: {e}"), + } + } } Ok(VmProof { diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index f3ca090d7..93f3ba563 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -2852,6 +2852,8 @@ fn build_traces( // ===================================================================== // PHASE 4: All → Bitwise lookups // ===================================================================== + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p4_bitwise_collect"); bitwise_ops.extend(collect_bitwise_from_lt(&lt_ops)); // MUL/DVRM dedup their per-unique bit-gated lookups PER CHIP INSTANCE, so pass // the same chunk size used to split them into instances (see chunk_and_generate @@ -2905,10 +2907,14 @@ fn build_traces( .map(|chunk| chunk.len().next_power_of_two().max(4) - chunk.len()) .sum(); bitwise_ops.extend(collect_byte_check_ops_for_padding(num_padding_rows)); + #[cfg(feature = "instruments")] + drop(__sp); // ===================================================================== // PHASE 5: Generate final traces (parallelized) // ===================================================================== + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p5_generate_tables"); // A monolithic run or the final continuation epoch terminates on the program's // halt ECALL.
Intermediate continuation epochs do not halt, so fall back to the @@ -3275,6 +3281,8 @@ fn build_traces( }; let local_to_global = local_to_global::generate_local_to_global_trace(&[]); + #[cfg(feature = "instruments")] + drop(__sp); Ok(Traces { cpus, bitwise, @@ -3995,16 +4003,26 @@ impl Traces { // Phase 0: ELF → DECODE + instructions // IMPORTANT: Use generate_decode_trace (same as compute_precomputed_commitment) // so the DECODE trace row ordering matches the AIR's hardcoded commitment. + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p0_decode"); let instructions = decode::instructions_from_elf(elf) .map_err(|e| Error::Execution(format!("Failed to parse instructions: {e}")))?; let (decode_trace, decode_pc_to_row) = decode::generate_decode_trace(&instructions); + #[cfg(feature = "instruments")] + drop(__sp); // Phase 1: Logs → CPU operations + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p1_cpu_ops"); let cpu_ops = collect_cpu_ops(logs, &instructions)?; + #[cfg(feature = "instruments")] + drop(__sp); // Phase 2: Collect + route all ops let mut memory_state = MemoryState::from_image(initial_image); let mut register_state = RegisterState::from_init(register_init); + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p2a_collect_cpu"); let ( memw_ops, load_ops, @@ -4018,7 +4036,11 @@ impl Traces { ec_scalar_ops, ecdas_ops, ) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); + #[cfg(feature = "instruments")] + drop(__sp); + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p2b_collect_all"); let ops = collect_all_ops( cpu_ops, memw_ops, @@ -4035,9 +4057,13 @@ impl Traces { &mut register_state, is_final, ); + #[cfg(feature = "instruments")] + drop(__sp); // Phases 3-5 - build_traces( + #[cfg(feature = "instruments")] + let __sp = stark::instruments::span("p3to5_build_traces"); + let result = build_traces( ops, Some(initial_image), &memory_state, @@ 
-4051,7 +4077,10 @@ impl Traces { private_input, is_final, l2g_memory_bookend, - ) + ); + #[cfg(feature = "instruments")] + drop(__sp); + result } /// Generates all traces from execution logs (legacy API).