Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crypto/crypto/src/merkle_tree/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ where
})
}

/// Create a root only Merkle tree placeholder: stores the commitment root
/// but no nodes. Used when paths are gathered from a device resident copy
/// (GPU) instead of this host tree, so the host nodes are never built.
/// [`get_proof_by_pos`](Self::get_proof_by_pos) must NOT be called on it.
pub fn from_root(root: B::Node) -> Self {
MerkleTree {
root,
nodes: Vec::new(),
#[cfg(feature = "disk-spill")]
mmap_backing: None,
}
}

/// Create a Merkle tree from pre-hashed leaf nodes.
///
/// This skips the `hash_leaves` step, useful when leaves have already been
Expand Down
34 changes: 34 additions & 0 deletions crypto/math-cuda/kernels/keccak.cu
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,40 @@ extern "C" __global__ void keccak_merkle_level(
finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32);
}

// Collect Merkle authentication paths for a batch of leaf positions from the
// resident tree `nodes`. Nodes are 32 bytes each and laid out with the inner
// nodes in [0..leaves_len-1] (root at index 0) followed by the leaves starting
// at index leaves_len-1. Each thread handles one query and climbs from its
// leaf towards the root, copying the sibling met at every level into `out`.
// The index arithmetic follows the CPU `build_merkle_path` (sibling_index /
// parent_index in crypto/crypto/src/merkle_tree/utils.rs):
//   leaf node = pos + leaves_len - 1
//   sibling   = node even ? node-1 : node+1
//   parent    = node even ? (node-1)/2 : node/2
// Result layout: `out[(q*depth + level)*32 .. +32]` holds the sibling at
// `level` for query q.
extern "C" __global__ void merkle_gather_paths(
    const uint8_t *nodes,
    const uint32_t *positions,  // leaf positions, length num_queries
    uint32_t num_queries,
    uint64_t leaves_len,
    uint32_t depth,             // = log2(leaves_len)
    uint8_t *out) {             // num_queries * depth * 32 bytes
    const uint32_t query = blockIdx.x * blockDim.x + threadIdx.x;
    if (query >= num_queries) return;

    // Each query owns `depth` consecutive 32-byte slots in `out`; advance a
    // running pointer by four 64-bit words (32 bytes) per level.
    // 32-byte nodes at 32-byte-aligned offsets (cuMemAlloc 256-aligned),
    // so copying through 64-bit words is safe.
    uint64_t *dst_words =
        reinterpret_cast<uint64_t *>(out + (uint64_t)query * depth * 32);

    uint64_t cur = leaves_len - 1 + (uint64_t)positions[query];
    for (uint32_t level = 0; level < depth; ++level, dst_words += 4) {
        const bool is_odd = (cur & 1ull) != 0;
        const uint64_t sibling = is_odd ? (cur + 1ull) : (cur - 1ull);
        const uint64_t *src_words =
            reinterpret_cast<const uint64_t *>(nodes + sibling * 32);
        #pragma unroll
        for (int w = 0; w < 4; ++w) dst_words[w] = src_words[w];
        // Step to the parent: odd -> cur/2, even -> (cur-1)/2.
        cur = (is_odd ? cur : (cur - 1ull)) >> 1;
    }
}

// ---------------------------------------------------------------------------
// Row-major ROW-PAIR leaf hashing.
//
Expand Down
93 changes: 93 additions & 0 deletions crypto/math-cuda/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub struct Backend {
pinned_hashes: Vec<Mutex<PinnedStaging>>,
util_stream: Arc<CudaStream>,
next: AtomicUsize,
/// VRAM budget (bytes) for table-session admission control. See
/// [`detect_vram_budget_bytes`].
vram_budget_bytes: u64,

// arith.ptx
pub vector_add_u64: CudaFunction,
Expand Down Expand Up @@ -154,6 +157,7 @@ pub struct Backend {
pub keccak_comp_poly_leaves_ext3: CudaFunction,
pub keccak_fri_leaves_ext3: CudaFunction,
pub keccak_merkle_level: CudaFunction,
pub merkle_gather_paths: CudaFunction,

// barycentric.ptx
pub barycentric_base_batched: CudaFunction,
Expand Down Expand Up @@ -181,6 +185,74 @@ pub struct Backend {
inv_twiddles: Mutex<Vec<Option<Arc<CudaSlice<u64>>>>>,
}

/// Raise the release threshold of the device's default memory pool so freed
/// stream-ordered allocations are kept for reuse instead of being returned to
/// the OS at each sync.
///
/// The threshold defaults to `u64::MAX` (retain indefinitely). Setting
/// `LAMBDA_VM_MEMPOOL_RELEASE_MB` to an integer overrides it with that many
/// mebibytes (saturating), for cases where retained-pool growth needs bounding;
/// an unset or unparsable value falls back to `u64::MAX`.
///
/// Best-effort: any failure (e.g. a device/driver without stream-ordered
/// allocator support) leaves the default behaviour untouched and is not
/// reported to the caller.
fn retain_default_mempool(ctx: &CudaContext) {
    use cudarc::driver::sys;

    // SAFETY: raw CUDA driver call. `ctx.cu_device()` is a valid device for
    // the just-created context and the out-pointer is a valid stack slot.
    let pool = unsafe {
        let mut handle: sys::CUmemoryPool = std::ptr::null_mut();
        if sys::cuDeviceGetDefaultMemPool(&mut handle as *mut _, ctx.cu_device())
            .result()
            .is_err()
        {
            // No default pool available: keep the driver's default behaviour.
            return;
        }
        handle
    };

    // Bytes the pool keeps before returning memory to the OS.
    let threshold: u64 = match std::env::var("LAMBDA_VM_MEMPOOL_RELEASE_MB") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(mb) => mb.saturating_mul(1024 * 1024),
            Err(_) => u64::MAX,
        },
        Err(_) => u64::MAX,
    };

    // SAFETY: raw CUDA driver call. `pool` was just returned by the driver and
    // `threshold` outlives the call; the driver reads the value as a u64.
    // The result is deliberately discarded (best-effort).
    unsafe {
        let _ = sys::cuMemPoolSetAttribute(
            pool,
            sys::CUmemPool_attribute_enum::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
            &threshold as *const u64 as *mut core::ffi::c_void,
        )
        .result();
    }
}

/// Compute the device VRAM budget, in bytes, used for table-session admission
/// control.
///
/// Resolution order:
/// 1. If `LAMBDA_VM_VRAM_BUDGET_MB` is set and parses as an integer, that many
///    mebibytes (saturating) is returned. This is used to force the throttle
///    in tests.
/// 2. Otherwise the budget is 80% of the total device memory reported by the
///    driver, leaving headroom for the context, module code, and retained
///    pool blocks.
/// 3. If the driver query fails, `u64::MAX` is returned, which disables
///    budgeting (chunks fall back to the core bound size alone).
fn detect_vram_budget_bytes(ctx: &CudaContext) -> u64 {
    use cudarc::driver::sys;

    const BYTES_PER_MB: u64 = 1024 * 1024;

    let override_mb = std::env::var("LAMBDA_VM_VRAM_BUDGET_MB")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok());
    if let Some(mb) = override_mb {
        return mb.saturating_mul(BYTES_PER_MB);
    }

    // The context handle itself is not passed to the driver; the query acts
    // on whichever context is current.
    let _ = ctx;

    let (mut free_bytes, mut total_bytes) = (0usize, 0usize);
    // SAFETY: raw driver query writing into two stack slots. The caller's
    // context is already current (it was just created in `init`).
    let query = unsafe {
        sys::cuMemGetInfo_v2(&mut free_bytes as *mut usize, &mut total_bytes as *mut usize)
            .result()
    };

    match query {
        // 80% of total; dividing first avoids intermediate overflow.
        Ok(_) => (total_bytes as u64) / 5 * 4,
        // Any error yields the budgeting-disabled sentinel.
        Err(_) => u64::MAX,
    }
}

impl Backend {
fn init() -> Result<Self> {
let ctx = CudaContext::new(0)?;
Expand All @@ -190,6 +262,17 @@ impl Backend {
// before returning), so the tracking is pure overhead. Disable it.
unsafe { ctx.disable_event_tracking() };

// Retain freed device memory in the stream-ordered pool for reuse.
//
// cudarc routes CudaStream::alloc* through cuMemAllocAsync, drawing from
// the device default memory pool. Its release threshold defaults to 0,
// so every freed buffer goes back to the OS at the next sync and the
// prover's large LDE/FRI buffers are rebuilt from scratch each op.
// Raising the threshold keeps freed blocks in the pool so a same-size
// allocation skips a real driver allocation. Best-effort: on any error
// we keep the current behaviour.
retain_default_mempool(&ctx);

let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?;
let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?;
let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?;
Expand Down Expand Up @@ -225,6 +308,8 @@ impl Backend {
// Length = TWO_ADICITY + 1 to allow indexing at log_n = TWO_ADICITY.
let max_log = GoldilocksField::TWO_ADICITY as usize + 1;

let vram_budget_bytes = detect_vram_budget_bytes(&ctx);

Ok(Self {
vector_add_u64: arith.load_function("vector_add_u64")?,
gl_add: arith.load_function("gl_add_kernel")?,
Expand Down Expand Up @@ -257,6 +342,7 @@ impl Backend {
keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?,
keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?,
keccak_merkle_level: keccak.load_function("keccak_merkle_level")?,
merkle_gather_paths: keccak.load_function("merkle_gather_paths")?,
barycentric_base_batched: bary.load_function("barycentric_base_batched")?,
barycentric_ext3_batched: bary.load_function("barycentric_ext3_batched")?,
barycentric_base_batched_strided: bary
Expand All @@ -282,9 +368,16 @@ impl Backend {
pinned_hashes,
util_stream,
next: AtomicUsize::new(0),
vram_budget_bytes,
})
}

/// Returns the VRAM budget, in bytes, stored on this backend for
/// table-session admission control.
///
/// The value is `u64::MAX` when budgeting is disabled (the device memory
/// query failed). See the docs on the `vram_budget_bytes` field.
pub fn vram_budget_bytes(&self) -> u64 {
self.vram_budget_bytes
}

/// Round-robin over the stream pool. Concurrent callers get different
/// streams so their kernel launches overlap on the GPU.
pub fn next_stream(&self) -> Arc<CudaStream> {
Expand Down
21 changes: 13 additions & 8 deletions crypto/math-cuda/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl FriCommitState {
pub fn fold_and_commit_layer(
&mut self,
zeta_raw: [u64; 3],
) -> Result<(Vec<u8>, Vec<u64>, Vec<u8>)> {
) -> Result<(Vec<u64>, crate::lde::GpuMerkleTree)> {
#[cfg(feature = "test-faults")]
check_fault_injection()?;
let be = backend()?;
Expand Down Expand Up @@ -214,17 +214,22 @@ impl FriCommitState {
self.stream.clone_dtoh(&view)?
};

// Tree nodes.
let nodes_bytes: Vec<u8> = self.stream.clone_dtoh(&nodes_dev)?;
debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32);

let mut root = vec![0u8; 32];
root.copy_from_slice(&nodes_bytes[0..32]);
// Keep the layer tree resident on device; copy only the 32-byte root so
// R4 query openings gather paths on device instead of copying the tree.
let mut root = [0u8; 32];
self.stream
.memcpy_dtoh(&nodes_dev.slice(0..32), &mut root)?;
self.stream.synchronize()?;

self.a_is_input = !self.a_is_input;
self.current_n = n_out;

Ok((root, layer_evals, nodes_bytes))
let tree = crate::lde::GpuMerkleTree {
nodes: std::sync::Arc::new(nodes_dev),
leaves_len: num_leaves,
root,
};
Ok((layer_evals, tree))
}

/// Final fold, no Merkle commit. Returns the single ext3 output
Expand Down
Loading
Loading