Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 117 additions & 24 deletions zstd/src/decoding/literals_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,57 @@ struct LoopBounds {
alloc_upper_bound: usize,
}

/// Decodes `SPB` symbols from each of the four interleaved streams in a
/// single burst iteration.
///
/// `SPB` is a const generic so the trip count is known at compile time
/// and the `SPB × 4` decode steps can be flattened into straight-line
/// code, mirroring the donor's hardcoded `for symbol in 0..5` unroll in
/// `HUF_decompress4X1_usingDTable_internal_fast_c_loop`. A runtime
/// `0..symbols_per_burst` bound would instead keep an induction variable
/// and a per-iteration trip check alive.
///
/// Each step, for stream `s`:
/// 1. takes the table index from `bits[s] >> table_shift`;
/// 2. loads the packed `u16` entry (low byte = symbol, high byte =
///    number of bits to consume);
/// 3. stores the symbol at `target_ptr + cursors[s]` and bumps
///    `cursors[s]` by one;
/// 4. shifts the consumed bits out of `bits[s]`.
///
/// Streams are visited in order 0, 1, 2, 3 within each round.
///
/// # Safety
/// The caller must uphold the same preconditions as the surrounding
/// burst body:
/// - every index `bits[s] >> table_shift` reached during the burst is
///   `< packed.len()` (bounded by the table shift);
/// - every `cursors[s]` that gets written through is
///   `< alloc_upper_bound` (the caller's lockstep gate and
///   `debug_assert!`);
/// - `target_ptr` points into an allocation that covers all of those
///   cursor offsets.
#[inline(always)]
unsafe fn burst_decode_symbols<const SPB: usize>(
    bits: &mut [u64; 4],
    cursors: &mut [usize; 4],
    target_ptr: *mut u8,
    packed: &[u16],
    table_shift: u32,
) {
    // One table lookup + symbol store + bit consumption for a single
    // stream. A macro (not a loop over lanes) keeps the expansion
    // identical to hand-written straight-line code.
    macro_rules! decode_lane {
        ($lane:literal) => {{
            let slot = (bits[$lane] >> table_shift) as usize;
            // SAFETY: caller guarantees `slot < packed.len()`.
            let cell: u16 = unsafe { *packed.get_unchecked(slot) };
            // Low byte of the entry is the decoded symbol; the cast
            // truncates to exactly that byte.
            // SAFETY: caller guarantees `cursors[$lane]` lies inside the
            // allocation behind `target_ptr`. A raw store is used so no
            // reference to not-yet-written bytes is ever created.
            unsafe { target_ptr.add(cursors[$lane]).write(cell as u8) };
            cursors[$lane] += 1;
            // High byte of the entry is the bit count to discard.
            bits[$lane] <<= cell >> 8;
        }};
    }

    for _round in 0..SPB {
        decode_lane!(0);
        decode_lane!(1);
        decode_lane!(2);
        decode_lane!(3);
    }
}

/// Donor-parity 4-stream HUF decode burst loop. Single code path —
/// no kernel dispatch, no SIMD-fallback hybrid. Mirrors
/// `huf_decompress.c:HUF_decompress4X1_usingDTable_internal_fast_c_loop`:
Expand Down Expand Up @@ -691,30 +742,72 @@ unsafe fn run_4stream_burst_loop<K: CpuKernel>(
// `.write()` (raw store) is used instead of `&mut [u8]`
// indexing so no Rust reference is ever formed to the
// uninitialised tail before its byte is written.
for _ in 0..symbols_per_burst {
let idx0 = (bits[0] >> table_shift) as usize;
let entry0 = unsafe { *packed.get_unchecked(idx0) };
unsafe { target_ptr.add(cursors[0]).write((entry0 & 0xFF) as u8) };
cursors[0] += 1;
bits[0] <<= (entry0 >> 8) & 0xFF;

let idx1 = (bits[1] >> table_shift) as usize;
let entry1 = unsafe { *packed.get_unchecked(idx1) };
unsafe { target_ptr.add(cursors[1]).write((entry1 & 0xFF) as u8) };
cursors[1] += 1;
bits[1] <<= (entry1 >> 8) & 0xFF;

let idx2 = (bits[2] >> table_shift) as usize;
let entry2 = unsafe { *packed.get_unchecked(idx2) };
unsafe { target_ptr.add(cursors[2]).write((entry2 & 0xFF) as u8) };
cursors[2] += 1;
bits[2] <<= (entry2 >> 8) & 0xFF;

let idx3 = (bits[3] >> table_shift) as usize;
let entry3 = unsafe { *packed.get_unchecked(idx3) };
unsafe { target_ptr.add(cursors[3]).write((entry3 & 0xFF) as u8) };
cursors[3] += 1;
bits[3] <<= (entry3 >> 8) & 0xFF;
//
// Dispatch the inner loop on the compile-time symbol count so
// the dominant cases get a fully-unrolled body (donor's
// hardcoded `for symbol in 0..5`). `symbols_per_burst` is
// loop-invariant, so the match is hoisted out of the `while`
// by loop-unswitching; each arm monomorphises
// `burst_decode_symbols::<SPB>` to straight-line code. SPB=5
// covers `max_num_bits ∈ {10, 11}` — the large-alphabet
// literal-heavy case that dominates decode cost; 6 covers
// {8, 9}, 7 covers {7}. Rarer small-max tables (few symbols,
// cheap overall) fall through to the dynamic loop.
match symbols_per_burst {
5 => unsafe {
burst_decode_symbols::<5>(
&mut bits,
&mut *cursors,
target_ptr,
packed,
table_shift,
);
},
6 => unsafe {
burst_decode_symbols::<6>(
&mut bits,
&mut *cursors,
target_ptr,
packed,
table_shift,
);
},
7 => unsafe {
burst_decode_symbols::<7>(
&mut bits,
&mut *cursors,
target_ptr,
packed,
table_shift,
);
},
_ => {
for _ in 0..symbols_per_burst {
let idx0 = (bits[0] >> table_shift) as usize;
let entry0 = unsafe { *packed.get_unchecked(idx0) };
unsafe { target_ptr.add(cursors[0]).write((entry0 & 0xFF) as u8) };
cursors[0] += 1;
bits[0] <<= (entry0 >> 8) & 0xFF;

let idx1 = (bits[1] >> table_shift) as usize;
let entry1 = unsafe { *packed.get_unchecked(idx1) };
unsafe { target_ptr.add(cursors[1]).write((entry1 & 0xFF) as u8) };
cursors[1] += 1;
bits[1] <<= (entry1 >> 8) & 0xFF;

let idx2 = (bits[2] >> table_shift) as usize;
let entry2 = unsafe { *packed.get_unchecked(idx2) };
unsafe { target_ptr.add(cursors[2]).write((entry2 & 0xFF) as u8) };
cursors[2] += 1;
bits[2] <<= (entry2 >> 8) & 0xFF;

let idx3 = (bits[3] >> table_shift) as usize;
let entry3 = unsafe { *packed.get_unchecked(idx3) };
unsafe { target_ptr.add(cursors[3]).write((entry3 & 0xFF) as u8) };
cursors[3] += 1;
bits[3] <<= (entry3 >> 8) & 0xFF;
}
}
}

// Reload all 4 streams (donor `HUF_4X1_RELOAD_STREAM`).
Expand Down
Loading