diff --git a/CMakeLists.txt b/CMakeLists.txt index 922b04b89..629788d60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -312,6 +312,7 @@ if(BUILD_CUDA) set_target_properties(bitsandbytes PROPERTIES CUDA_SEPARABLE_COMPILATION ON + CUDA_RESOLVE_DEVICE_SYMBOLS ON ) endif() if(BUILD_HIP) diff --git a/agents/flute_kernel_guide.md b/agents/flute_kernel_guide.md new file mode 100644 index 000000000..344a69b90 --- /dev/null +++ b/agents/flute_kernel_guide.md @@ -0,0 +1,1145 @@ +# FLUTE Kernel: Comprehensive Technical Guide + +This document provides a thorough analysis of the FLUTE (Flexible Lookup Table Engine) +kernel for lookup-table-quantized LLM inference. It covers the kernel architecture, +implementation details, performance characteristics, and relevance to the bitsandbytes +kbit GEMM kernel design. + +--- + +## Executive Summary: FLUTE vs. Bitsandbytes kbit + +FLUTE and the bitsandbytes kbit GEMM kernel are two different approaches to the same +problem — fused dequantization + matrix multiplication for lookup-table-quantized LLM +weights — with comparable instruction-level efficiency. 
+ +**They are similar in:** +- Core operation: load compressed weights, dequant via codebook, tensor core MMA +- Instruction count per element: roughly comparable (~3-6 ops depending on bit width) +- Performance regime: both achieve 2-4x over FP16 at small batch, converging to dense + throughput at large batch (fundamental property of weight-only quantization) +- Both require offline weight repacking for GEMM-friendly tile layout + +**FLUTE trades flexibility for per-shape optimization:** +- Built on CUTLASS 3 / CuTe — gets multi-stage pipelining and Stream-K for free +- Requires per-(shape, bits, group_size, GPU) compilation and auto-tuning +- Shape-specialized binaries limit deployment flexibility +- CUTLASS dependency (pinned to v3.4.1) +- 3-bit uses bit-slice decomposition (1+2 split) — different code path, ~33% more + instructions than 4-bit +- No 5-bit support +- Focused on A100/A6000; RTX 4090 supported but less tuned + +**kbit trades CUTLASS infrastructure for simplicity and breadth:** +- Self-contained hand-written CUDA, no external dependencies +- Uniform code path for K=2,3,4,5 via bit-plane format — no special cases +- No per-shape recompilation or tuning needed +- Register-based codebook lookup via `__shfl_sync` (zero memory, 1 cycle) +- E4M4 absmax (1 byte per block of 32) — finer granularity than FLUTE's FP16 scales +- Developed and tested on RTX 4090; not yet tuned for data center GPUs + +**Bottom line:** FLUTE does not have a fundamental architectural advantage over the kbit +design. The two kernels have similar instruction-level efficiency with different +engineering trade-offs. FLUTE's head start is that it exists as a working fused GEMM +today and has been benchmarked on data center GPUs. Once the kbit GEMM is implemented +and tuned for A100/H100, there is no reason to expect FLUTE would be meaningfully +faster. 
The bitsandbytes ecosystem integration (Transformers, PEFT, Accelerate) and
+broader bit-width support (K=2-5 uniform) are practical advantages that matter more
+than marginal kernel-level performance differences.
+
+FLUTE has limited real-world adoption despite its EMNLP 2024 publication — it is not
+a default in any major inference framework and has known issues (shape specialization,
+numerical instability at some configurations, bfloat16 underperformance). It is best
+understood as an academic contribution that validates the LUT-quantized GEMM approach,
+not as a production system to compete against.
+
+---
+
+## Table of Contents
+
+1. [Overview and Motivation](#1-overview-and-motivation)
+2. [The Core Problem: LUT-Quantized GEMM on GPUs](#2-the-core-problem-lut-quantized-gemm-on-gpus)
+3. [Three-Part Solution Architecture](#3-three-part-solution-architecture)
+4. [Offline Weight Restructuring (Section 3.1)](#4-offline-weight-restructuring)
+5. [Vectorized Lookup Table with Duplication (Section 3.2)](#5-vectorized-lookup-table-with-duplication)
+6. [Stream-K Workload Partitioning (Section 3.3)](#6-stream-k-workload-partitioning)
+7. [CUTLASS 3 / CuTe Implementation](#7-cutlass-3--cute-implementation)
+8. [Source Code Structure](#8-source-code-structure)
+9. [Kernel Configuration and Tuning](#9-kernel-configuration-and-tuning)
+10. [NormalFloat and NFL (Learned NormalFloat)](#10-normalfloat-and-nfl-learned-normalfloat)
+11. [Performance Analysis](#11-performance-analysis)
+12. [Comparison with Other Kernels](#12-comparison-with-other-kernels)
+13. [Detailed Comparison: FLUTE vs. Bitsandbytes kbit](#13-detailed-comparison-flute-vs-bitsandbytes-kbit)
+14. [Limitations and Known Issues](#14-limitations-and-known-issues)
+15. [Links and References](#15-links-and-references)
+
+---
+
+## 1. Overview and Motivation
+
+**Paper**: "Fast Matrix Multiplications for Lookup Table-Quantized LLMs"
+**Authors**: Han Guo, William Brandon, Radostin Cholakov, Jonathan Ragan-Kelley,
+Eric P. 
Xing, Yoon Kim +**Published**: EMNLP 2024 (Findings) +**ArXiv**: 2407.10960 (v4, January 17, 2025) + +FLUTE is a CUDA kernel engine for efficient inference of weight-quantized LLMs where +the quantization is based on **lookup tables** (LUT) rather than uniform (linear) +integer quantization. This distinction is critical: + +- **Uniform quantization** (e.g., standard INT4): `dequant(q) = q * scale + zero` + Simple arithmetic, easily fused with GEMM. + +- **LUT quantization** (e.g., NF4, custom codebooks): `dequant(q) = table[q] * scale` + Requires a table lookup per element, which is fundamentally different from arithmetic + dequantization and presents unique GPU optimization challenges. + +FLUTE supports arbitrary lookup tables, making it compatible with: +- Integer quantization: int4, int3, int2 +- Floating-point: fp4, fp3, fp2 +- Normal float variants: nf4, nf3, nf2 +- Learned Normal Float (NFL): A learnable extension to QLoRA's nf4 +- Custom arbitrary tables (any 2^K values) + +At batch sizes < 32 with group size 128 (typical LLM inference), FLUTE achieves +**2-4x speedup** over existing GEMM kernels and **1.5-2x end-to-end throughput +improvement** on LLaMA-3 models. + +--- + +## 2. The Core Problem: LUT-Quantized GEMM on GPUs + +The paper identifies three fundamental challenges for building a high-performance +LUT-quantized matmul kernel on GPUs: + +### Challenge 1: Tensor Core Data Layout Requirements + +Tensor Cores have strict requirements on data types, shapes, and layouts. Quantized +weights at non-standard bit widths (especially 3-bit) cannot be packed evenly into +the 128-bit vectorized memory accesses that feed the tensor core pipeline. For +example: + +- 4-bit: 32 values per 128-bit word (clean) +- 3-bit: 42.67 values per 128-bit word (does not divide evenly) +- 2-bit: 64 values per 128-bit word (clean) + +The 3-bit case is problematic: you cannot load a clean set of 3-bit values with a +single 128-bit async copy instruction. 
+ +### Challenge 2: Dynamic Indexing Limitations + +LUT-based dequantization requires dynamic indexing into a table. GPUs do not natively +support efficient dynamic indexing of data in their fastest on-chip storage (registers). +The alternatives are: + +- **Registers**: No dynamic indexing. Would need a switch/case statement. +- **Shared memory**: Supports dynamic indexing but has limited bandwidth (32 banks, + 32-bit each) and potential bank conflicts. +- **Constant memory**: Broadcasts to all threads if they access the same address, but + serializes if they access different addresses. + +Since each thread typically looks up a different index, shared memory is the natural +choice, but naive implementations suffer from bank conflicts. + +### Challenge 3: Wave Quantization at Small Problem Sizes + +With low-bit quantization and small batch sizes, the weight matrix is small, producing +fewer output tiles. If the number of tiles doesn't fill all SMs evenly, some SMs sit +idle in the last "wave" (wave quantization). This is a significant efficiency loss +for the small-matrix regime that LLM inference typically operates in. + +--- + +## 3. Three-Part Solution Architecture + +FLUTE addresses these challenges with three complementary techniques: + +1. **Offline weight restructuring** (Section 3.1): Reorder quantized weights at + model-load time so that after dequantization, the data is already in the layout + that tensor cores expect. This moves bit-manipulation overhead from runtime to + load time. + +2. **Vectorized and duplicated lookup table** (Section 3.2): Store the LUT in shared + memory, but access two values simultaneously (vectorization) and duplicate the + table across banks (duplication) to eliminate bank conflicts. + +3. **Stream-K workload partitioning** (Section 3.3): Use fine-grained work distribution + across SMs to minimize wave quantization effects. + +--- + +## 4. Offline Weight Restructuring + +### The Problem + +Consider 3-bit quantization. 
Each weight is a 3-bit index into a lookup table. +Packing these into 128-bit words for async copy: + +- 128 / 3 = 42.67 — doesn't divide evenly +- You can't load exactly N complete 3-bit values with a single vector load + +Standard approaches pad to 4 bits (wasting 25% of storage) or use complex runtime +bit manipulation to extract 3-bit fields from packed words. + +### FLUTE's Approach: Bit-Slice Decomposition + +FLUTE splits the 3-bit representation into two "bit-slices": +- A **1-bit partition** (the most significant bit) +- A **2-bit partition** (the two least significant bits) + +Each partition is stored separately and can be loaded with standard 128-bit async +copy instructions: +- The 1-bit partition: 128 values per 128-bit word +- The 2-bit partition: 64 values per 128-bit word + +After loading both slices into registers, they are combined via bit manipulation: + +``` +combined_index = (bit_slice_1 << 2) | bit_slice_2 +``` + +This avoids any runtime overhead from non-aligned bit extraction. + +### Offline Reordering + +The quantized weight matrix is permuted offline (at model load time) so that after +the bit-slices are loaded and dequantized, the resulting values are already in the +exact register layout that the `m16n8k16` tensor core instruction expects. + +This is possible because the quantized weights are **static** during inference — they +never change. So the permutation is computed once and applied once. At runtime, the +kernel simply loads pre-permuted data and feeds it to tensor cores without any +reordering overhead. + +The permutation accounts for: +- The thread-to-element mapping of the MMA instruction +- The shared-memory-to-register copy layout (ldmatrix) +- The bit-slice separation + +### For 4-bit Quantization + +4-bit is simpler: 32 values per 128-bit word, clean division. No bit-slice +decomposition needed. The offline restructuring still applies — weights are permuted +so that the dequantized layout matches tensor core expectations. 

+
+### For 2-bit Quantization
+
+2-bit is also clean: 64 values per 128-bit word. Same approach as 4-bit.
+
+---
+
+## 5. Vectorized Lookup Table with Duplication
+
+### The Problem: Shared Memory Bank Conflicts
+
+The lookup table for dequantization is stored in shared memory. For K-bit
+quantization, the table has 2^K entries. When 32 threads in a warp each look up
+a different index, the access pattern can cause bank conflicts.
+
+Shared memory has 32 banks, each 4 bytes wide. If two threads access different
+4-byte words in the same bank, the accesses are serialized.
+
+For a 4-bit LUT with 16 entries of 2 bytes (half precision) each:
+- Total LUT size: 32 bytes
+- The 16 half values occupy banks 0-7 (2 half values per 4-byte bank)
+- Threads accessing different indices in the same bank conflict
+
+### Vectorized Lookup
+
+FLUTE creates an **expanded lookup table** containing every possible pair of
+consecutive indices. Instead of looking up one value at a time, it looks up two
+values simultaneously.
+
+For 4-bit quantization:
+- Original table: 2^4 = 16 entries of `half` (2 bytes each) = 32 bytes
+- Vectorized table: 2^8 = 256 entries of `half2` (4 bytes each) = 1024 bytes
+
+The kernel extracts pairs of 4-bit indices from packed data, forms an 8-bit index,
+and uses it to load a `half2` containing both dequantized values in a single shared
+memory transaction. This halves the number of shared memory accesses.
+
+For 3-bit quantization:
+- Original: 2^3 = 8 entries
+- Vectorized: 2^6 = 64 entries of `half2` = 256 bytes
+
+### LUT Duplication
+
+Even with vectorization, bank conflicts can still occur. For the 4-bit vectorized
+table (256 × 4 bytes = 1024 bytes), the entries map across 256 bank positions,
+cycling through all 32 banks 8 times. If 8 threads in a warp happen to access
+entries that map to the same bank, you get an 8-way conflict. 
+ +FLUTE mitigates this by **duplicating** the entire vectorized table multiple times +in shared memory, placing each copy at a different base address that shifts the +bank alignment. When a thread would conflict on one copy, it can access a +different copy that maps to a different bank. + +The number of duplicates is a tuning parameter. For 4-bit with 256 entries: +- 1 copy: up to 8-way conflicts +- 2 copies: up to 4-way conflicts +- 4 copies: up to 2-way conflicts +- 8 copies: conflict-free (8 KB total — still small vs. 48-164 KB shared memory) + +For 3-bit with 64 entries: +- Vectorized table is only 256 bytes +- 2-way conflicts max, so fewer duplicates needed + +The duplication count is selected during auto-tuning (see Section 9). + +### Implementation Detail + +The dequantization in the kernel (`packbits_utils.hpp`) supports multiple modes: + +```cpp +enum QuantMapModeEnum { + Basic, // Standard per-element LUT lookup + Vectorized, // Vectorized half2 lookup (default) + Vectorized_32, // Vectorized with 32-entry table + Vectorized_16, // Vectorized with 16-entry table + Vectorized_8, // Vectorized with 8-entry table + WarpShuffle, // __shfl_sync-based lookup (registers) + Marlin // Marlin-style arithmetic dequant +}; +``` + +The `Vectorized` mode is the default and primary mode. The `WarpShuffle` mode uses +`__shfl_sync()` for in-register lookups (similar to bitsandbytes' approach). The +`Marlin` mode delegates to Marlin's `lop3`-based arithmetic dequantization for +uniform INT4. + +--- + +## 6. Stream-K Workload Partitioning + +### The Problem: Wave Quantization + +Standard GEMM kernels partition the output matrix into tiles and launch one +threadblock per tile. If the number of tiles doesn't divide evenly by the number +of SMs, the last wave has idle SMs. + +Example: 32 output tiles on 132 SMs (H100). Only 32/132 = 24% utilization. +Even with split-K to create more blocks, the granularity is coarse. 
+ +### Stream-K Solution + +Stream-K (introduced by CUTLASS) partitions work at a finer granularity than +output tiles. Instead of assigning one complete output tile to each threadblock, +it distributes individual K-tiles across threadblocks. + +The work is linearized: all (M-tile, N-tile, K-tile) combinations are laid out +in a 1D sequence and distributed evenly across a fixed number of threadblocks +(typically = num_SMs). + +When multiple threadblocks contribute to the same output tile (because they +process different K-ranges), they synchronize via a semaphore-based fixup: + +1. Non-finishing blocks store partial accumulator values in a global workspace +2. Synchronization via `cutlass::Barrier` primitives (`wait_lt`, `wait_eq`, + `arrive_inc`) +3. The finishing block reads, reduces, and writes the final result + +### FLUTE's Stream-K Implementation + +FLUTE's `TileScheduler` (`tile_scheduler_utils.hpp`) implements both Split-K +and Stream-K modes: + +```cpp +enum DecompositionModeEnum { + SplitK, // Fixed K-split across slices + StreamK // Fine-grained K-tile distribution +}; +``` + +In Stream-K mode: +- Total tiles = `tiles_M × tiles_N × tiles_K` +- `tiles_per_block = total_tiles / num_blocks` +- `blocks_special = total_tiles % num_blocks` (these get one extra tile) + +The `FixupHelper` handles the inter-block reduction: +- `BACKWARDS` flag reverses logical block ordering so the last block coordinates +- Partial sums accumulated in FP32 for numerical stability +- Global reduction done in FP16 to minimize memory traffic + +--- + +## 7. CUTLASS 3 / CuTe Implementation + +FLUTE is built entirely on **CUTLASS 3.x** (specifically v3.4.1) using the +**CuTe** (CUDA Templates) abstraction layer. This is a significant architectural +choice that differs from hand-written CUDA kernels like Marlin. + +### CUTLASS 3.x Architecture Layers + +CUTLASS 3.x decomposes GEMM into composable layers: + +1. **Device layer**: Top-level API, manages grid launch +2. 
**Kernel layer**: Thread block-level orchestration +3. **Collective layer**: Multi-thread cooperation patterns (sync, pipelining) +4. **Tiled MMA/Copy**: Spatial micro-kernels for tiling +5. **Atom layer**: Hardware-specific instructions (MMA, ldmatrix, cp.async) + +FLUTE customizes the **Collective** and **Tiled Copy** layers to inject LUT +dequantization into the standard GEMM pipeline. + +### CuTe Abstractions Used + +- **Layouts**: `SmemLayoutA`, `SmemLayoutQ`, `SmemLayoutS`, etc. with 3x3x3 + swizzle patterns for bank-conflict-free shared memory access +- **TiledCopy**: Separate copy operations for A matrix (activations), Q matrix + (packed quantized weights), Q2 (second bit-slice for 3-bit), and S (scales) +- **TiledMma**: SM80_16x8x16 MMA operations for half/bfloat16 +- **Async copy**: `cp.async` for global → shared memory transfers with predication +- **Register fragments**: `FragA`, `FragB`, `FragC`, `FragS` for tensor core inputs + +### The GEMM Pipeline + +The kernel's main loop (from `qgemm_kernel.hpp`) follows this pattern: + +``` +1. PREFETCH: Load lookup table from global → shared memory (once) + +2. TILE LOOP: For each K-tile: + a. Async copy: input tile (X) from global → shared + b. Async copy: quantized weight slices (Q1, Q2, S) from global → shared + c. Wait for copies to complete + +3. FRAGMENT LOOP: For each register-backed fragment within the tile: + a. Copy fragment data from shared → registers (ldmatrix for A) + b. Load packed weight data from shared → registers + c. For 3-bit: Combine bit-slices in registers + Q_combined = combine(Q1_reg, Q2_reg) + d. Vectorized dequantization: + W_dequant = vec_dequantize(Q_combined, scale_reg, LUT_shared) + e. Tensor core MMA: + Y_reg = tensor_core_mma(Y_reg, X_reg, W_dequant) + +4. 
EPILOGUE: Convert FP32 accumulators → FP16, write to global memory + (with Stream-K fixup if needed) +``` + +### Multi-Stage Pipeline + +The kernel uses circular shared memory buffers with configurable pipeline depth +(`Stages` template parameter, typically 2-4). This overlaps global→shared copies +with shared→register copies and computation: + +- Stage N: Computing MMA on fragments from shared memory +- Stage N+1: Loading next tile from global to shared memory + +The number of stages is a tuning parameter (see Section 9). + +--- + +## 8. Source Code Structure + +Repository: https://github.com/HanGuo97/flute + +### CUDA/C++ Sources (`flute/csrc/`) + +| File | Purpose | +|---|---| +| `qgemm_kernel.hpp` | **Main kernel**: Template device function `qgemm_device` and host launcher `qgemm_host`. Contains the full GEMM pipeline with dequantization. | +| `config.hpp` | **Configuration**: `GemmConfig` template with all tile sizes, thread counts, shared memory layouts, MMA configurations, copy operations. | +| `packbits_utils.hpp` | **Dequantization**: `DequantizationTraits` template with specializations for 2/3/4-bit, vectorized/shuffle/Marlin modes. Core dequant logic. | +| `tile_scheduler_utils.hpp` | **Work distribution**: `TileScheduler` with Split-K and Stream-K modes. `FixupHelper` for inter-block reduction. | +| `conversion_utils.hpp` | **Type conversion**: Register-level tensor type conversion using CUTLASS converters. | +| `marlin_utils.hpp` | **Marlin compatibility**: Marlin-style `lop3`-based INT4 dequantization for uniform quantization mode. | +| `qgemm_kernel_raw_generated.cu` | **Generated instantiations**: Pre-compiled kernel variants for supported shapes/configs. | +| `qgemm_kernel_example.cu` | **Example**: Template instantiation example showing how to configure a kernel. | +| `qgemm.cpp` | **PyTorch binding**: C++ entry point that dispatches to the appropriate kernel template. 
| +| `hadamard_transform_cuda.cu` | **Hadamard transform**: CUDA kernel for the HadaCore integration. | +| `cutlass_extensions_bf16.h` | **BF16 extensions**: Additional bfloat16 support utilities. | + +### Python Sources (`flute/`) + +| File | Purpose | +|---|---| +| `ops.py` | PyTorch custom op registration with fake tensor implementations for torch.compile. | +| `tune.py` | Auto-tuning: benchmarks multiple kernel configurations and selects the fastest. | +| `packbits_utils.py` | Weight packing: `to_binary`, `from_binary`, `pack_bools_into_integers`, `pack_integer_tensors`. | +| `nf_utils.py` | NormalFloat codebook generation via inverse Gaussian CDF. Quantization/dequantization. | +| `utils.py` | General utilities. | +| `codegen_utils.py` | Code generation helpers for kernel instantiation. | + +### Key Configuration Parameters (`config.hpp`) + +The `GemmConfig` template is parameterized by: + +``` +Data types: + T — compute type (half, bfloat16) + TQ — quantized weight type (int16) + TC — accumulation type (float) + TR — reduction type + +Threading: + Warps — number of warps per block + Threads — total threads (must be multiple of 128) + +Quantization: + NumBits — 2, 3, or 4 + GroupSize — 32, 64, 128, or 256 + NumPacked — number of packed elements per int16 + +Tiling: + TileM, TileK, TileP — tile dimensions for M, K, packed-weight axes + Stages — pipeline depth (2-4) + StagesG — pipeline stages for scale loading + +Copy operations: + G2SCopySizeA, G2SCopySizeQ, etc. — transfer granularity + +MMA configuration: + MmaThrM, MmaThrN, MmaThrK — thread layout within MMA + MmaPrmM, MmaPrmN, MmaPrmK — permutation within MMA +``` + +--- + +## 9. Kernel Configuration and Tuning + +FLUTE is **shape-specialized** — for each combination of (M, N, K, num_bits, +group_size, dtype, GPU), a specific kernel configuration is selected via benchmarking. 
+ +### What Gets Tuned + +The `template_id` parameter encodes a specific combination of: +- Tile sizes (TileM, TileN, TileK) +- Pipeline stages +- Number of LUT duplicates (for bank conflict mitigation) +- Thread block configuration +- MMA layout + +### Tuning Process + +From `tune.py`: + +1. For a given matrix shape and quantization config, enumerate candidate + `template_id` values +2. For each candidate, run the kernel at least 100 times +3. Measure average execution time +4. Select the fastest `template_id` +5. Cache the result for future use + +The tuned `template_id` is stored in the model's metadata and passed to `qgemm()` +at inference time. + +### Correctness Verification + +After tuning, the framework runs correctness checks: +- Generates test cases with known-good outputs +- Compares against thresholds: FP16 ≤ 2.0e-3, BF16 ≤ 1.1e-2 + +### Limitations + +- Each new model shape requires re-tuning +- Different tensor parallel configurations create different shapes +- The team is working on JIT tuning to reduce this constraint +- As of January 2025, experimental auto-tune support removes some shape/GPU + specialization + +--- + +## 10. NormalFloat and NFL (Learned NormalFloat) + +### NormalFloat (NF) Codebook + +The standard NF codebook (same concept as QLoRA's NF4) generates quantization +levels from the inverse Gaussian CDF: + +1. Generate 2^(b-1) evenly-spaced probability values in [δ, 1/2] and [1/2, 1-δ] + where δ = 1/2 × (1/30 + 1/32) +2. Convert to quantiles via inverse CDF: q_i = Φ^(-1)(p_i) +3. Normalize: q̃_i = q_i / q_{2^b - 1} + +The result is a symmetric codebook in [-1, 1] optimized for normally-distributed +weights. + +### Group-Level Scaling + +For a weight group u with absmax s = max(|u|): +- Quantize: c_j = argmin_i |q̃_i - u_j/s| +- Dequantize: T[Q_{ij}] × s_{(i×j) mod B} + +### NFL (Learned NormalFloat) + +NFL extends NF by learning the scale parameter σ̃: + +1. Reformulate quantization: c_j = argmin_i |sσ̃q_i - u_j| +2. 
Initialize σ̃ from the standard NF normalization constant: σ̃ = 1/Φ^(-1)(1-δ) +3. Optimize σ̃ via gradient descent on negative log-likelihood +4. Use calibration data: 128 examples × 2048 tokens from WikiText-2 +5. Apply straight-through estimator for the argmin gradient +6. Save the learned scale as sσ̃/σ (preserves dequantization format) + +This adds minimal overhead (learning one scalar per group) but measurably improves +quantization quality. + +### Results + +LLaMA-3.1 8B with NFL W4G64: +- WikiText-2 perplexity: 6.24 (vs 6.31 unquantized — actually better due to + the calibration fitting) + +LLaMA-3.1 70B with NFL W4G64: +- WikiText-2 perplexity: 3.09 (vs 2.82 unquantized) + +--- + +## 11. Performance Analysis + +### Kernel-Level Benchmarks + +**4-bit quantization, group size 128:** +- 2-4× speedup over FP16 `torch.mm` at batch < 32 +- Outperforms bitsandbytes and BitBLAS-NF4 LUT kernels +- Competitive with uniform-quantization kernels (Marlin, BitBLAS-INT4) +- At batch sizes > 32, advantage diminishes (GEMM becomes compute-bound) + +**3-bit quantization:** +- Supported where most other LUT kernels don't support it at all +- Consistent speedups across group sizes 32, 64, 128, 256 + +### End-to-End LLM Throughput + +**LLaMA-3 8B** (batch=1, single GPU): +- 4-bit, group=128: ~2.2× tokens/s improvement, perplexity 6.2 +- 3-bit, group=128: ~2.4× tokens/s improvement, perplexity 4.6 + +**LLaMA-3 70B** (tensor parallelism): +- 4-bit, group=256: ~1.9-2.0× improvement (4×A6000, 2×A100) +- 3-bit, group=256: ~1.7-2.0× improvement (4×A6000, 2×A100) + +**LLaMA-3.1 405B**: Enables single-node inference (impossible without +quantization) + +### Hardware-Specific Performance + +Optimized for **Ampere GPUs** (A100, A6000, RTX 4090). Not yet optimized for +Hopper (H100), though it runs. bfloat16 is slower than float16, likely due to +lack of Ampere hardware-accelerated bfloat16 atomic-add. + +--- + +## 12. Comparison with Other Kernels + +### FLUTE vs. 
Marlin + +| Aspect | FLUTE | Marlin | +|---|---|---| +| **Quantization type** | LUT-based (arbitrary codebooks) | Uniform (INT4/INT8 linear) | +| **Bit widths** | 2, 3, 4 | 4, 8 | +| **Dequant method** | Shared memory LUT lookup | `lop3` bit manipulation in registers | +| **Work distribution** | Stream-K (CUTLASS) | Custom stripe partitioning | +| **Implementation** | CUTLASS 3 / CuTe templates | Hand-written CUDA | +| **Weight format** | Offline-restructured, bit-sliced | Custom tiled INT4 packing | +| **Bank conflict handling** | LUT duplication + vectorization | N/A (arithmetic dequant) | +| **Target GPU** | Ampere (SM80) | Ampere + Hopper | +| **Performance (4-bit)** | Competitive at batch < 32 | Slightly faster at small batch | +| **3-bit support** | Yes | No | +| **Codebook flexibility** | Arbitrary | Linear only | + +Key insight: Marlin uses register-level arithmetic for dequantization (no memory +access), while FLUTE uses shared memory lookup. For uniform quantization, Marlin's +approach is faster. For non-uniform/codebook quantization, FLUTE's approach is +necessary. + +FLUTE also includes a `Marlin` mode in its `QuantMapModeEnum` that delegates to +Marlin-style `lop3` dequantization for the uniform INT4 case. + +### FLUTE vs. bitsandbytes (Current) + +| Aspect | FLUTE | bitsandbytes | +|---|---|---| +| **Approach** | Fused dequant+GEMM | Separate dequant, then cuBLAS | +| **Tensor cores** | Yes (via CUTLASS MMA) | No (dequant only, cuBLAS for GEMM) | +| **LUT mechanism** | Vectorized shared memory | `__shfl_sync` in registers | +| **Bit widths** | 2, 3, 4 | 2, 3, 4, 5 (kbit branch) | +| **Performance** | 2-4× over dequant+cuBLAS | Baseline (dequant+cuBLAS) | + +### FLUTE vs. 
Proposed kbit GEMM (from kbit_gemm_context.md) + +| Aspect | FLUTE | Proposed kbit GEMM | +|---|---|---| +| **Framework** | CUTLASS 3 / CuTe | Hand-written CUDA | +| **LUT storage** | Shared memory (vectorized+duplicated) | Registers (`__shfl_sync`) | +| **Work distribution** | Stream-K (CUTLASS built-in) | Persistent kernel with split-K | +| **Bit widths** | 2, 3, 4 | 2, 3, 4, 5 | +| **Weight format** | Bit-slice decomposed, offline restructured | Bit-plane (from `__ballot_sync`), tiled | +| **Scale format** | FP16 group scales | E4M4 absmax (1 byte per block of 32) | +| **Block size** | Configurable (32, 64, 128, 256) | Fixed at 32 | +| **Target GPU** | Ampere | Ampere + Hopper | + +--- + +## 13. Detailed Comparison: FLUTE vs. Bitsandbytes kbit + +This section provides a side-by-side analysis of every major design decision, +referencing the actual bitsandbytes kbit implementation on the +`feature/kbit-quantization` branch (`csrc/ops.cu` lines 649-869) and the planned +GEMM kernel design from `agents/kbit_gemm_context.md`. + +### 13.1 Codebook Lookup Mechanism + +This is the single biggest architectural difference between the two kernels. + +**FLUTE: Vectorized shared memory LUT with duplication** + +FLUTE stores the lookup table in shared memory. To reduce the number of shared +memory transactions, it creates a "vectorized" table containing every possible +*pair* of consecutive indices. For 4-bit quantization: + +- Original table: 16 entries × 2 bytes (half) = 32 bytes +- Vectorized table: 256 entries × 4 bytes (half2) = 1024 bytes + +The kernel extracts pairs of 4-bit indices from packed weight data, forms an +8-bit combined index, and fetches a `half2` from shared memory in one transaction. +This halves the number of shared memory reads. + +To handle bank conflicts (up to 8-way for 4-bit), FLUTE duplicates the entire +vectorized table multiple times in shared memory at different base addresses, +shifting bank alignment. 
The duplication count is auto-tuned per shape/GPU. +Worst case: 8 copies × 1 KB = 8 KB of shared memory for the table alone. + +Modes in `packbits_utils.hpp`: +```cpp +enum QuantMapModeEnum { + Basic, // Per-element LUT lookup + Vectorized, // Vectorized half2 lookup (default) + WarpShuffle, // __shfl_sync-based (register) + Marlin // lop3 arithmetic dequant +}; +``` + +**kbit: Register shuffle via `__shfl_sync`** + +The bitsandbytes kbit kernel stores the codebook in a single register per lane: + +```cpp +// ops.cu line ~766 (standalone dequant), GEMM plan uses same pattern: +float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f; +// ... +float val = __shfl_sync(0xFFFFFFFF, cb, idx) * amax; +``` + +For the GEMM kernel, the codebook is pre-converted to half at kernel start: +```cpp +half cb_h = (lane < (1 << K_BITS)) + ? __float2half(codebook[lane]) : __float2half(0.0f); +// In inner loop: +half val = __shfl_sync(0xFFFFFFFF, cb_h, idx); +``` + +Each lane holds one codebook entry in a register. Lookup is a warp shuffle with +arbitrary per-thread source lane selection. Cost: 1 cycle on the shuffle unit, +zero memory bandwidth consumed. + +**Why kbit's approach is better for our use case:** + +- Our codebooks have at most 2^5 = 32 entries (K=2..5), fitting exactly in a + 32-lane warp. No shared memory needed at all. +- Shuffle is 1 cycle with zero bank conflicts by definition. +- No shared memory space consumed by the table — more room for A and B tiles. +- No duplication/tuning complexity. +- The shuffle approach is already proven in the existing standalone dequant + kernel (`ops.cu` line 783). + +FLUTE needs shared memory because it's designed to be generic — it supports +arbitrary table sizes that could exceed 32 entries. For exactly this reason, +FLUTE also offers a `WarpShuffle` mode, but it isn't the default. + +### 13.2 Weight Packing Format + +**FLUTE: Contiguous K-bit packing with bit-slice decomposition** + +FLUTE packs quantized indices contiguously. 
For 4-bit: two 4-bit indices per +`uint8`, or 8 per `uint32`. The packed `int16` values are loaded via 128-bit +async copies. + +For 3-bit (which doesn't divide evenly into 128-bit words), FLUTE uses +**bit-slice decomposition**: split each 3-bit index into a 1-bit MSB and a +2-bit LSB, store them in separate arrays, load each with clean 128-bit copies, +and combine in registers: + +``` +combined_index = (bit_slice_1 << 2) | bit_slice_2 +``` + +The offline restructuring permutes packed weights so that after loading and +dequantization, values land in the exact register positions that `m16n8k16` +tensor cores expect. This means the kernel never does runtime reordering. + +**kbit: Bit-plane format via `__ballot_sync`** + +The bitsandbytes quantize kernel (`ops.cu` line 706) produces K separate +`uint32` bit-plane words per block of 32 elements: + +```cpp +// pack_kbit_warp: +for (int bit = 0; bit < K; bit++) + packed_words[bit] = __ballot_sync(0xFFFFFFFF, (qval >> bit) & 1); +``` + +Bit-plane 0 contains bit 0 of all 32 elements, bit-plane 1 contains bit 1, etc. +The GEMM repack kernel retiles this from flat sequential into +`[k_tile][n_tile][col][k_block][bit_plane]` order for coalesced tile loads. 
+ +To extract an index in the GEMM kernel: +```cpp +for (int b = 0; b < K_BITS; b++) + idx |= ((planes[b] >> row) & 1) << b; +``` + +**Comparison:** + +| Aspect | FLUTE | kbit | +|---|---|---| +| Storage unit | Contiguous K-bit fields in int16 | K separate uint32 bit-plane words | +| 3-bit handling | Bit-slice split (1+2), two separate loads | Natural: K=3 bit-planes, same as K=2,4,5 | +| 5-bit handling | Not supported | Natural: K=5 bit-planes | +| Extraction cost | Shift+mask to isolate K-bit field from packed word | K shift+mask+OR to assemble index from planes | +| Memory footprint | K bits per element | K bits per element (identical) | +| Runtime reordering | None (offline permutation matches tensor core layout) | None (repack kernel produces tile-aligned layout) | + +The bit-plane format's key advantage is uniformity: K=2,3,4,5 all work +identically with no special cases. FLUTE needs separate code paths for 3-bit +(the bit-slice decomposition). The bit-plane extraction cost (K INT32 ops per +element) runs on integer ALU concurrent with tensor core MMA, so it's +effectively hidden. + +### 13.3 Scale/Absmax Format and Application + +**FLUTE: FP16 group scales** + +FLUTE uses standard half-precision scales with configurable group sizes +(32, 64, 128, 256). Dequantization is: `value = table[index] * scale`. + +The scales are loaded from global → shared memory alongside the packed weights, +with their own pipeline stage (`StagesG`). Inside the fragment loop, scale values +are applied via `__hmul2()` paired half multiplication. + +Storage overhead per element: 2 bytes / group_size. For group_size=128: 0.0156 +bytes/element. For group_size=32: 0.0625 bytes/element. 
+ +**kbit: E4M4 absmax (1 byte per block of 32)** + +The kbit system uses a custom 8-bit floating point format for the per-block +absmax value (`ops.cu` line 722): + +```cpp +// E4M4: 4-bit exponent (bias=11) + 4-bit mantissa +// Normal: 2^(e-11) * (1 + m/16), range ~[6.1e-5, 31.0] +// Decode: construct IEEE 754 float via bit manipulation +unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23 + | (unsigned int)m << 19; +return __uint_as_float(ieee); +``` + +Dequantization is: `value = codebook[index] * absmax`. The absmax is always +per-block (blocksize=32), giving fine-grained scaling. + +Storage overhead: 1 byte / 32 = 0.03125 bytes/element. This is: +- 2× less than FLUTE with group_size=32 (0.0625 bytes/element) +- Same as FLUTE with group_size=64 in absolute bytes, but kbit gets + per-32-element granularity vs FLUTE's per-64-element granularity +- Max relative error from E4M4: 6.25% (1/16 from 4-bit mantissa) + +In the GEMM kernel, absmax decode happens once per block-of-32 per column per +K-tile (256 decodes total for TILE_N=128, TILE_K=64). The decode is ~5 integer +ALU ops, negligible compared to MMA throughput. + +**Why E4M4 matters:** + +At K=2 (2-bit quantization), each element is 2 bits = 0.25 bytes. FLUTE's FP16 +scale at group_size=128 adds 0.0156 bytes/element (6.25% overhead). kbit's E4M4 +at blocksize=32 adds 0.03125 bytes/element (12.5% overhead) but with 4× finer +granularity — and in 1 byte instead of 2. The finer granularity typically +improves quantization quality more than the coarser group hurts it. + +### 13.4 Work Distribution and Split-K + +**FLUTE: Stream-K via CUTLASS** + +FLUTE uses CUTLASS's built-in Stream-K decomposition (`tile_scheduler_utils.hpp`). 
+All (M,N,K) tiles are linearized into a 1D work sequence and distributed evenly +across `num_blocks` threadblocks: + +```cpp +tiles_per_block = total_tiles / num_blocks; +blocks_special = total_tiles % num_blocks; // get +1 tile +``` + +When multiple blocks contribute to the same output tile (different K-ranges), +the `FixupHelper` coordinates via `cutlass::Barrier` primitives. Partial sums +are stored in FP32 in a global workspace; the finishing block reduces and +converts to FP16. + +Grid launch: `dim3(num_blocks)` for Stream-K mode. + +**kbit: Persistent kernel with linearized work assignment** + +The kbit GEMM plan launches exactly `num_SMs` blocks. Work items are linearized +as (m_tile, n_tile, k_chunk) triples, ordered so that all k_chunks for a given +(m,n) output tile are contiguous: + +```cpp +int work_per_block = div_ceil(total_work, gridDim.x); +int my_start = blockIdx.x * work_per_block; +int my_end = min(my_start + work_per_block, total_work); +``` + +Key optimization: when consecutive work items share the same output tile, the +block keeps accumulators in registers across k_chunks — no intermediate write. +The pipeline restarts between chunks (~2-tile cost), but accumulators persist. + +Output write uses a three-way branch: +- Full K-range ownership → write FP16 directly (common case for large M) +- First contributor → write FP32 to workspace (overwrite, acts as zero+write) +- Subsequent contributors → atomicAdd FP32 to workspace + +A per-tile atomic counter tracks when the last contributor finishes, which +then converts FP32 → FP16 in the final output. 
+ +**Comparison:** + +| Aspect | FLUTE (Stream-K) | kbit (Persistent) | +|---|---|---| +| Implementation | CUTLASS built-in | Hand-written | +| Launch config | `dim3(num_blocks)` | `dim3(num_SMs)` | +| Granularity | Per K-tile | Per k_chunk (multiple K-tiles) | +| Sync mechanism | `cutlass::Barrier` semaphores | `atomicAdd` + atomic counter | +| Accumulator reuse | Each block handles isolated work items | Consecutive same-(m,n) items share accumulators | +| Reduction | Finishing block reduces all partials | Last contributor (via counter) converts to FP16 | +| Dependency | Requires CUTLASS | Self-contained | + +The persistent kernel's accumulator-reuse optimization is significant: for +problems where each block handles multiple k_chunks for the same output tile, +it avoids writing and re-reading intermediate FP32 partials. Stream-K doesn't +have this optimization — each block writes its partial to global memory. + +### 13.5 Bit-Width Support + +| Bits | FLUTE | kbit | +|---|---|---| +| 2-bit | Yes (build from source) | Yes | +| 3-bit | Yes (bit-slice decomposition) | Yes (bit-plane, no special case) | +| 4-bit | Yes (primary target) | Yes | +| 5-bit | No | Yes | + +FLUTE's lack of 5-bit support is likely because the bit-slice approach would +need a 2+3 or 1+4 split, adding another code path. The kbit bit-plane format +handles K=5 identically to K=2,3,4. + +### 13.6 Implementation Framework + +**FLUTE: CUTLASS 3 / CuTe templates** + +- All tiling, pipelining, and MMA via CUTLASS abstractions +- Shared memory layouts use CuTe's swizzle patterns (3×3×3) +- Async copies via `cp.async` managed by CUTLASS pipeline stages +- `TiledCopy` and `TiledMma` handle thread-to-data mapping +- `GemmConfig` template encodes the full kernel configuration +- Code generation produces template instantiations per (shape, bits, GPU) + +Pros: Less custom infrastructure to write, well-tested pipeline/sync code. 
+Cons: Massive template expansion, slow compile, CUTLASS version dependency
+(pinned to v3.4.1), shape-specialized binaries.
+
+**kbit: Hand-written CUDA**
+
+- Custom tiling with explicit loop structures
+- Manual `cp.async` pipeline (2-stage double buffer)
+- Inline PTX for `ldmatrix` and `mma.sync` instructions
+- No external dependencies beyond CUDA toolkit
+- Single compilation unit (`kernels.cu`) with template params `<K, M_BLOCKS>`
+- Kernel config selected at launch time based on M dimension
+
+Pros: Full control over register allocation and scheduling, no dependency
+management, single binary works for all shapes of the same (K, M_BLOCKS).
+Cons: Must implement all infrastructure manually, more potential for bugs in
+pipeline/sync code.
+
+### 13.7 Tensor Core Usage
+
+Both kernels use the same fundamental MMA instruction: `m16n8k16` with FP16
+inputs and FP32 accumulation.
+
+**FLUTE**: CuTe's `SM80_16x8x16_F32F16F16F32` atom, configured via `TiledMma`
+with customizable thread layout (`MmaThrM × MmaThrN × MmaThrK`) and
+permutation (`MmaPrmM × MmaPrmN × MmaPrmK`).
+
+**kbit**: Direct inline PTX `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`
+instruction. Thread-to-fragment mapping hand-computed:
+- 4 threads per column (lane/4 = column index)
+- Row indices: {2i, 2i+1, 2i+8, 2i+9} where i = lane%4
+- FragA: M_BLOCKS × half2[2] per k-sub-tile
+- FragB: half2[2] per N-block (dequantized on the fly, not stored)
+
+The kbit design explicitly exploits the 4-threads-per-column property for
+shared memory access: when loading bit-plane words, 4 threads read the same
+K addresses, getting a free 4-way broadcast with zero bank conflicts. FLUTE
+doesn't need this optimization because its offline restructuring already
+places data in the correct register positions.
+
+### 13.8 Pipeline Design
+
+**FLUTE**: Configurable multi-stage pipeline (2-4 stages, auto-tuned). 
+Separate pipeline stages for different data streams: +- `Stages`: Main pipeline depth for A and Q tiles +- `StagesG`: Separate depth for scale factor loading +- `StagesGView`: View stages for handling GroupSize/TileK relationships + +Circular shared memory buffers managed by CUTLASS pipeline abstractions. + +**kbit**: 2-stage double-buffered pipeline (fixed). +- Stage 0 and Stage 1 alternate in shared memory +- `cp_async_fence()` and `cp_async_wait<1>()` for synchronization +- Pipeline restarts when switching k_chunks (2-tile cost) + +The kbit approach is simpler but less flexible. FLUTE's ability to tune the +pipeline depth per shape can yield better performance in specific cases. + +### 13.9 Offline Weight Preparation + +Both require offline weight restructuring, but the details differ. + +**FLUTE offline restructuring:** + +1. Quantize weights to K-bit indices using a codebook (NF or custom) +2. Pack indices contiguously (for 3-bit: split into 1+2 bit-slices) +3. **Permute** packed words so that after loading and dequantization, values + land directly in tensor core register positions +4. The permutation encodes: thread-to-element MMA mapping + ldmatrix layout + + bit-slice separation + +This is a single combined permutation that folds multiple concerns together. + +**kbit offline restructuring:** + +1. Quantize weights via `kQuantizeBlockwise_kbit` → flat bit-plane format + (K uint32 words per block of 32 elements, sequential) +2. Encode absmax from float32 to E4M4 uint8 +3. **Retile** bit-planes from flat → `[k_tile][n_tile][col][k_block][bit_plane]` +4. **Retile** absmax from flat → `[k_tile][n_tile][col][k_block]` + +The kbit repack is a simpler gather/permutation — it only changes the tile +layout, not the data format within tiles. No MMA-layout-aware permutation is +needed because the GEMM kernel handles the thread-to-element mapping at runtime +via the bit-plane extraction + `__shfl_sync` codebook lookup. 
+ +### 13.10 Summary: When to Prefer Which Approach + +**FLUTE is better when:** +- You need arbitrary codebook sizes (> 32 entries) +- You want to leverage CUTLASS's tested infrastructure +- You need auto-tuning across many different matrix shapes +- You need Stream-K's sophisticated edge-case handling +- 3-bit and 4-bit are the primary targets + +**kbit is better when:** +- Codebooks are ≤ 32 entries (K ≤ 5) — register shuffle is strictly faster +- You need 5-bit support +- You want zero external dependencies +- Fine-grained E4M4 absmax (per-32-element) is important +- You need a single binary that works across all shapes (no re-tuning) +- You want Hopper GPU support from the start +- The bit-plane format naturally handles all K values uniformly + +--- + +## 14. Limitations and Known Issues + +1. **Shape specialization**: Each matrix shape requires separate tuning and + compilation. Different tensor parallel configurations create different shapes, + limiting supported models. (Partial mitigation via auto-tune as of Jan 2025.) + +2. **Ampere-only optimization**: Not yet leveraging Hopper features (TMA, warp + specialization, distributed shared memory). Runs on H100 but not at peak. + +3. **bfloat16 performance**: Slower than float16 on Ampere due to lack of + hardware-accelerated bfloat16 atomic-add (needed for Stream-K reduction). + +4. **Large batch degradation**: Performance advantage diminishes at batch > 32 + as the GEMM becomes compute-bound rather than memory-bandwidth-bound. + +5. **Numerical issues**: Some instability reported with 4-bit, group-size=256 + on A100. + +6. **No 5-bit support**: FLUTE supports 2, 3, 4-bit only. The kbit design + supports 5-bit as well. + +--- + +## 15. 
Links and References + +### Primary Sources + +- **Paper (ArXiv)**: https://arxiv.org/abs/2407.10960 +- **Paper (PDF)**: https://arxiv.org/pdf/2407.10960 +- **Paper (HTML)**: https://arxiv.org/html/2407.10960 +- **Paper (ACL Anthology)**: https://aclanthology.org/2024.findings-emnlp.724/ +- **GitHub Repository**: https://github.com/HanGuo97/flute +- **HuggingFace Paper Page**: https://huggingface.co/papers/2407.10960 + +### Source Code (Key Files) + +- **Main kernel**: https://github.com/HanGuo97/flute/blob/main/flute/csrc/qgemm_kernel.hpp +- **Configuration**: https://github.com/HanGuo97/flute/blob/main/flute/csrc/config.hpp +- **Dequantization**: https://github.com/HanGuo97/flute/blob/main/flute/csrc/packbits_utils.hpp +- **Tile scheduling**: https://github.com/HanGuo97/flute/blob/main/flute/csrc/tile_scheduler_utils.hpp +- **Weight packing**: https://github.com/HanGuo97/flute/blob/main/flute/packbits_utils.py +- **NF utilities**: https://github.com/HanGuo97/flute/blob/main/flute/nf_utils.py +- **Auto-tuning**: https://github.com/HanGuo97/flute/blob/main/flute/tune.py +- **Ops/dispatch**: https://github.com/HanGuo97/flute/blob/main/flute/ops.py + +### Pre-Quantized Models + +- **HuggingFace Hub**: Models under the `HanGuo97` organization + - LLaMA-3.1: 8B, 70B, 405B (base + instruct, NFL W4G64 default) + - LLaMA-3: 8B, 70B + - Gemma-2: 9B, 27B + +### Related Projects + +- **CUTLASS 3.x**: https://github.com/NVIDIA/cutlass (required dependency, v3.4.1) +- **HIGGS**: Vector dequantization extension, NAACL 2025 +- **HadaCore**: Hadamard transform integration +- **Marlin**: https://github.com/IST-DASLab/marlin (comparison kernel for uniform INT4) +- **LUT-GEMM**: Earlier work on lookup-table-based GEMM kernels +- **LUT Tensor Core (arxiv 2408.06003)**: Hardware/software co-design for LUT operations + +### Blog Posts and Analysis + +- **MarkTechPost**: 
https://www.marktechpost.com/2024/07/26/flute-a-cuda-kernel-designed-for-fused-quantized-matrix-multiplications-to-accelerate-llm-inference/ +- **Semantic Scholar**: https://www.semanticscholar.org/paper/Fast-Matrix-Multiplications-for-Lookup-LLMs-Guo-Brandon/be66705b36912679ea373184aaf057aa365d292a +- **AlphaXiv Discussion**: https://www.alphaxiv.org/abs/2407.10960 + +### Installation + +```bash +# Default (CUDA 12.1) +pip install flute-kernel + +# CUDA 11.8 +pip install flute-kernel -i https://flute-ai.github.io/whl/cu118 + +# CUDA 12.4 +pip install flute-kernel -i https://flute-ai.github.io/whl/cu124 + +# From source (required for 2-bit) +git clone https://github.com/HanGuo97/flute.git +cd flute +pip install -e . +``` + +### Citation + +```bibtex +@inproceedings{guo2024flute, + title={Fast Matrix Multiplications for Lookup Table-Quantized LLMs}, + author={Guo, Han and Brandon, William and Cholakov, Radostin and + Ragan-Kelley, Jonathan and Xing, Eric P. and Kim, Yoon}, + booktitle={Findings of EMNLP}, + year={2024} +} +``` diff --git a/agents/kbit_gemm_context.md b/agents/kbit_gemm_context.md new file mode 100644 index 000000000..45d68c9a0 --- /dev/null +++ b/agents/kbit_gemm_context.md @@ -0,0 +1,1391 @@ +# kbit GEMM Kernel: Complete Design Context + +This document captures the full design analysis for implementing a fused kbit +dequantization + GEMM kernel in bitsandbytes. It covers the existing kbit +quantization implementation, the Marlin kernel architecture (as reference), and +the complete design for the new GEMM kernel. A developer reading this should +be able to implement the kernel without additional context. + +--- + +## Table of Contents + +1. [Existing kbit Implementation](#1-existing-kbit-implementation) +2. [Marlin Kernel Architecture (Reference)](#2-marlin-kernel-architecture-reference) +3. [GEMM Kernel Design](#3-gemm-kernel-design) +4. [Weight Storage Format and Repacking](#4-weight-storage-format-and-repacking) +5. 
[Inner Loop: Dequantization + MMA](#5-inner-loop-dequantization--mma) +6. [Persistent Kernel and Work Distribution](#6-persistent-kernel-and-work-distribution) +7. [Pipeline and Shared Memory](#7-pipeline-and-shared-memory) +8. [Codebook and Absmax Handling](#8-codebook-and-absmax-handling) +9. [Performance Analysis](#9-performance-analysis) +10. [Kernel Dispatch and Python Integration](#10-kernel-dispatch-and-python-integration) +11. [File Organization and Build](#11-file-organization-and-build) +12. [Error Budget](#12-error-budget) +13. [Template Instantiations](#13-template-instantiations) +14. [Future Considerations](#14-future-considerations) + +--- + +## 1. Existing kbit Implementation + +### 1.1 Overview + +The kbit quantization system lives on the `feature/kbit-quantization` branch. +It implements K-bit blockwise quantization for K=2,3,4,5 with blocksize=32 +(one warp = one quantization block). It uses a codebook-based approach where +each element is mapped to the nearest entry in a 2^K-entry codebook, then +packed into K bit-plane words using warp-level CUDA primitives. + +Currently, only standalone quantize and dequantize kernels exist. There is no +fused GEMM. The goal of this design is to add a fused dequant+GEMM kernel that +achieves high tensor core utilization at larger batch sizes. + +### 1.2 Codebook + +The codebook is generated by `create_normal_float_codebook(k)` in +`bitsandbytes/functional.py`. It places 2^K reconstruction levels at the +expected values of N(0,1) within 2^K equiprobable bins, then normalizes to +[-1, 1]. The codebook is: + +- Sorted ascending +- Roughly symmetric around 0 +- Normalized so `abs(max) == 1.0` +- Cached per (k, device) pair + +For K=4, this is conceptually similar to the existing NF4 datatype, though with +minor numerical differences (the existing NF4 has an asymmetric zero trick). + +The codebook is always stored as float32 and passed to CUDA kernels as +`const float*`. 
For the GEMM kernel, it will be converted to half precision +at kernel startup (see Section 8.1). + +### 1.3 Quantize Kernel + +Location: `csrc/ops.cu`, function `kQuantizeBlockwise_kbit` (line 682). + +``` +Template parameters: + T: input type (half, __nv_bfloat16, float) + K: bit width (2, 3, 4, 5) + +Launch config: + Block size: 256 threads (KBIT_THREADS_PER_BLOCK) + Grid: ceil(num_blocks / 8) where num_blocks = ceil(n / 32) + Each CUDA block has 8 warps, each warp processes one quantization block. + +Algorithm per warp: + 1. Each lane loads one element from A (lane_id maps 1:1 to element position) + 2. Convert to float + 3. Warp-reduce absmax via __shfl_down_sync butterfly reduction + 4. Lane 0 broadcasts absmax to all lanes via __shfl_sync + 5. Lane 0 writes absmax[warp_id] + 6. Normalize: val / max(absmax, 1e-8) + 7. Load codebook into lane registers: cb = codebook[lane_id] for lane < 2^K + 8. Brute-force nearest-neighbor search: + - Loop i = 0..2^K-1 + - Broadcast codebook[i] to all lanes via __shfl_sync(cb, i) + - Compare distance, track best index + 9. Pack via __ballot_sync: for each bit b in 0..K-1, + packed[b] = __ballot_sync(0xFFFFFFFF, (best_idx >> b) & 1) + This produces K uint32 words where word b contains bit b of all 32 lanes. + 10. Lanes 0..K-1 write their respective bit-plane word to + packed_out[warp_id * K + lane_id] +``` + +Key observations: +- The output is in "bit-plane" format: K uint32 words per block of 32 elements +- `__ballot_sync` collects one bit from all 32 lanes into a single uint32 +- The packed data layout in memory is sequential: block 0's K words, then + block 1's K words, etc. +- absmax is stored as float32 (later encoded to E4M4 on the Python side) + +### 1.4 Dequantize Kernel + +Location: `csrc/ops.cu`, function `kDequantizeBlockwise_kbit_vec` (line 753). 
+ +``` +Template parameters: + T: output type (half, __nv_bfloat16, float) + K: bit width (2, 3, 4, 5) + BLOCKS_PER_WARP: number of quantization blocks processed per warp iteration (4) + ABSMAX_T: absmax storage type (unsigned char for E4M4, half for fp16) + +Launch config: + Block size: 256 threads (8 warps) + Grid: ceil(num_warps / 8) where num_warps = ceil(num_blocks / BLOCKS_PER_WARP) + +Algorithm per warp: + 1. Load codebook into lane registers (once, amortized across BLOCKS_PER_WARP): + float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f; + + 2. For each of BLOCKS_PER_WARP=4 blocks: + a. Load absmax via load_absmax(absmax, block_id) + - For unsigned char: calls decode_e4m4_absmax() + - For half: simple cast to float + b. Load K bit-plane words using shuffle broadcast: + for (bit = 0; bit < K; bit++) { + unsigned int word = (lane_id == bit) ? packed_in[block_id * K + bit] : 0; + packed[bit] = __shfl_sync(0xFFFFFFFF, word, bit); + } + Only lane `bit` reads from global memory; all other lanes receive + the value via shuffle broadcast. This minimizes global memory + transactions (K reads per block instead of K*32). + c. Unpack index: for each bit, extract that bit from the plane word + at the current lane's position, OR them together: + idx = 0; + for (bit = 0; bit < K; bit++) + idx |= ((packed[bit] >> lane_id) & 1) << bit; + d. Codebook lookup via shuffle: + float val = __shfl_sync(0xFFFFFFFF, cb, idx) * amax; + e. Write output: out[block_start + lane_id] = (T)val; +``` + +Key observations: +- The shuffle-based bit-plane loading pattern (step 2b) exploits the fact that + each lane has a 1:1 correspondence with an element position. Only K lanes + do global loads; the rest get data via shuffle. This is specific to the + standalone dequant where threads map 1:1 to elements. +- In the GEMM kernel, this pattern CANNOT be used directly because threads + are organized around tensor core fragment positions, not element positions. 
+ Instead, bit-plane words will be loaded into shared memory by the async + pipeline, and each thread reads from shared memory for its specific column. + This is discussed in detail in Section 5. +- BLOCKS_PER_WARP=4 amortizes the codebook register load across 4 blocks. + In the GEMM kernel, the codebook is loaded once at kernel start and lives + in a register for the entire kernel lifetime -- even better amortization. + +### 1.5 E4M4 Absmax Format + +Location: `csrc/ops.cu`, function `decode_e4m4_absmax` (line 722). + +Format: 4-bit exponent + 4-bit mantissa with bias=11. +- Normal (e > 0): `2^(e - 11) * (1 + m/16)` +- Subnormal (e = 0): `2^(1 - 11) * (m/16)` = `2^(-10) * (m/16)` +- Zero (e = 0, m = 0): 0.0 + +Range: approximately [6.1e-5, 31.0] for normal values. +Max relative error: 1/16 = 6.25% (from the 4-bit mantissa). + +The decode implementation constructs an IEEE 754 float directly via bit +manipulation, avoiding any floating-point arithmetic: + +```cpp +__device__ __forceinline__ float decode_e4m4_absmax(unsigned char raw) { + if (raw == 0) return 0.0f; + int e = raw >> 4; + int m = raw & 0xF; + if (e == 0) { + return ldexpf((float)m, 1 - E4M4_BIAS - 4); // subnormal + } + unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23 + | (unsigned int)m << 19; + return __uint_as_float(ieee); +} +``` + +Cost: 1 comparison, 2 shifts, 1 OR, 1 add, 1 reinterpret. ~5 integer ALU ops. +The subnormal path uses `ldexpf` but is rarely taken in practice. + +The Python-side encoding is in `bitsandbytes/functional.py`: +`encode_absmax_e4m4()` and `decode_absmax_e4m4()`. + +Storage savings: 1 byte per block of 32 elements vs 4 bytes for float32. +This reduces absmax overhead from 0.125 bytes/element to 0.03125 bytes/element. 
+
+### 1.6 Bit-Plane Packing Helpers
+
+```cpp
+// Pack: collect bit `bit` from all 32 lanes into one uint32
+template <int K>
+__device__ __forceinline__ void pack_kbit_warp(unsigned char qval, unsigned int* packed_words) {
+    for (int bit = 0; bit < K; bit++)
+        packed_words[bit] = __ballot_sync(0xFFFFFFFF, (qval >> bit) & 1);
+}
+
+// Unpack: reconstruct K-bit index for this lane from K bit-plane words
+template <int K>
+__device__ __forceinline__ unsigned char unpack_kbit_warp(const unsigned int* packed_words, int lane_id) {
+    unsigned char val = 0;
+    for (int bit = 0; bit < K; bit++)
+        val |= ((packed_words[bit] >> lane_id) & 1) << bit;
+    return val;
+}
+```
+
+The pack operation uses `__ballot_sync` which collects one bit from each of
+the 32 lanes in a warp and assembles them into a single uint32 word.
+
+The unpack operation does the reverse: for a given lane position, it extracts
+one bit from each of K plane words and assembles them into a K-bit index.
+
+Both operations are O(K) in ALU ops. For K=4: 4 ballot_sync ops for packing,
+4 shift+mask+OR ops for unpacking.
+
+### 1.7 Template Instantiations
+
+Quantize: 12 variants (3 input types x 4 K values)
+Dequantize: 24 variants (3 output types x 2 absmax types x 4 K values)
+
+All instantiated via macros at the bottom of ops.cu (lines 821-869).
+
+### 1.8 Python Bindings
+
+Three layers:
+1. `bitsandbytes/_ops.py`: torch.library op definitions with fake tensor
+   implementations for torch.compile compatibility
+2. `bitsandbytes/backends/cuda/ops.py`: CUDA kernel dispatch -- maps dtype to
+   C function name suffix, handles fp32->E4M4 absmax encoding
+3. 
`csrc/pythonInterface.cpp`: unmangled C++ wrappers calling templates, + then extern "C" wrappers calling those + +The naming convention for C functions: +- Quantize: `cquantize_kbit_{fp16,bf16,fp32}_k{2,3,4,5}` +- Dequantize: `cdequantize_kbit_{fp16,bf16,fp32}_{u8abs,fp16abs}_k{2,3,4,5}` + +### 1.9 Test Coverage + +The test suite (`tests/test_kbit_quantization.py`, ~1400 lines) covers: +- Stage 0: Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref) +- Stage 4: CUDA quantize correctness (absmax, all dtypes, various sizes) +- Stage 5: CUDA dequantize correctness (matches ref, all dtypes, various sizes, error bounds) +- Stage 6: Error analysis on 1M+ elements (analytical bounds, MSE scaling, SQNR) +- Stage 7: Cross-validation against existing NF4 +- Stage 8: Performance benchmarks (bandwidth utilization, throughput scaling, NF4 comparison) +- Python API tests (round-trip, all dtypes, custom codebook, various sizes) +- Output dtype correctness (bf16/fp32 vs fp16 baseline) +- Asymmetric codebook tests (all-positive, all-negative, skewed, non-uniform) +- E4M4 encode/decode tests (round-trip, subnormals, monotonicity, uniqueness) + +### 1.10 Memory Layout of Packed Data + +The quantize kernel stores packed data in flat sequential order: + +``` +packed_out[warp_id * K + bit] = plane_word + +For a tensor A of n elements: + num_blocks = ceil(n / 32) + packed_out has num_blocks * K uint32 words + + Block i covers elements [32*i, 32*(i+1)) + packed_out[i*K + 0] = bit-plane 0 of block i (bit 0 of all 32 elements) + packed_out[i*K + 1] = bit-plane 1 of block i + ... + packed_out[i*K + K-1] = bit-plane K-1 of block i +``` + +For a weight matrix W[K_dim, N] flattened in row-major order: + Element (k, n) is at flat index k * N + n + It belongs to block floor((k * N + n) / 32) + +This flat layout is NOT suitable for GEMM tiling. The repack kernel +(Section 4) transforms it into a tiled layout. + +--- + +## 2. 
Marlin Kernel Architecture (Reference) + +The Marlin kernel in vllm (`csrc/quantization/marlin/`) is a highly optimized +mixed-precision GEMM for weight-only quantization. We use it as architectural +reference, not as code to copy. + +### 2.1 Key Design Elements + +Location: `vllm/csrc/quantization/marlin/marlin_template.h` + +**Tiling and SM partitioning (line 271-281):** +Marlin uses "stripe" partitioning where each threadblock processes a +contiguous run of tiles from a linearized 2D work grid. This ensures +good SM utilization for all shapes while minimizing cross-threadblock +reductions. + +**4-stage async pipeline (line 916-923):** +Uses `cp.async` to overlap global->shared memory transfers with computation. +The `cp_async_wait()` pattern ensures double-buffering. + +**Register double-buffering (line 927-939):** +Shared memory reads alternate between two sets of register fragments +(`frag_b_quant[k%2]`), hiding the shared memory read latency. + +**On-the-fly dequantization (line 1236-1237):** +INT4/INT8/FP4/FP8 values are dequantized in registers using `lop3` and +`prmt` PTX instructions. This is purely arithmetic (no memory access). +For kbit, we replace this with codebook lookup (see Section 5). + +**Tensor core MMA (line 1278-1281):** +Standard `m16n8k16` instructions on dequantized fp16 fragments, +accumulating in fp32. + +**Scale application (line 1244-1270):** +Group-wise or channel-wise scales applied to dequantized FragB before MMA. +Multiple code paths handle different group_blocks configurations. +For kbit, this simplifies dramatically because our blocksize=32 aligns +with TILE_K boundaries (see Section 8.2). + +### 2.2 Marlin Stripe Partitioning + +The stripe system (marlin_template.h:271-281, marlin.cu:362-516) solves +the problem of filling all SMs when the 2D tile count is less than the +SM count. 
+ +Example: 5 SMs, 3x3 tile grid (3 K-tiles x 3 N-columns): +``` +Column: 0 1 2 +K-tile 0: [0] [1] [3] +K-tile 1: [0] [2] [3] +K-tile 2: [1] [2] [4] +``` +Numbers = which SM handles that tile. + +The linearized tile sequence is distributed as contiguous "stripes" across +SMs. Properties: +- Perfect load balance (each SM gets total_tiles/num_SMs +/- 1) +- Minimized reductions (each SM crosses at most one column boundary) +- Adaptive split-K (automatically splits K when N-tiles < num_SMs) + +The reduction uses barrier_acquire/barrier_release on a locks array. + +We chose NOT to implement Marlin-style stripes. Instead, we use a persistent +kernel with explicit work assignment (see Section 6). + +### 2.3 Marlin Dispatch System + +Location: `marlin.cu:128-313` + +Two sets of thread configs: +- Small batch (thread_m_blocks=1): {128,128,256}, {64,128,128}, {128,64,128} +- Large batch (thread_m_blocks>1): {64,256,256}, {64,128,128}, {128,64,128} + (values are {thread_k, thread_n, num_threads}) + +The dispatch tries configs in priority order, picks the first valid one +(fits in shared memory, divides problem dimensions). If none work, reduces +thread_m_blocks and retries. + +For large M, Marlin splits M into parallel groups, each processed by a +separate set of SMs. 
+ +### 2.4 Key Differences from kbit GEMM + +| Aspect | Marlin | kbit GEMM | +|-----------------------|----------------------------------|----------------------------------| +| Dequant method | lop3 bit manipulation -> fp16 | Bit extraction -> codebook lookup -> scale | +| Codebook | None (linear INT4->FP16) | 4-32 entries via __shfl_sync | +| Scale granularity | Configurable group_blocks | Fixed: 1 E4M4 scale per 32 elements | +| K-tile alignment | Complex group boundary logic | Clean: TILE_K=64 = 2 blocks, no straddling | +| B tile in shmem | Standard INT4 size | Same for K=4, smaller for K=2,3 | +| Bit widths | 4 or 8 | 2, 3, 4, 5 | +| Zero points | Optional, complex logic | None (symmetric codebook) | +| Act-order | Supported (major complexity) | Not needed | +| Work distribution | Stripe partitioning | Persistent kernel + atomicAdd | + +--- + +## 3. GEMM Kernel Design + +### 3.1 Problem Statement + +Compute `C[M, N] = A[M, K_dim] * W_kbit[K_dim, N]^T` where: +- A is in fp16 (or bf16) +- W is stored in kbit format (bit-plane packed indices + E4M4 absmax + codebook) +- C is in fp16 (or bf16) + +The weight matrix W is quantized offline and stored in a GEMM-optimized +tiled layout (produced by the repack kernel). The codebook is shared across +all blocks. + +### 3.2 Tile Sizes + +``` +TILE_M = variable (16, 32, 48, 64 depending on M; controlled by M_BLOCKS template param) +TILE_N = 128 (or 256 for large batch configs) +TILE_K = 64 (= 2 quantization blocks of 32 elements each) +``` + +TILE_K=64 was chosen over TILE_K=32 because: +- Doubles compute per shared memory load of A +- Better compute-to-load ratio in the transition zone (M=32-128) +- Only adds one extra absmax value per column per tile (trivial complexity) +- 2 MMA k-sub-tile pairs instead of 1, better pipeline utilization + +With TILE_K=64, each K-tile spans exactly 2 kbit blocks (each 32 elements). +Each column has 2 absmax values per K-tile. 
The absmax boundary falls exactly +between k_sub=1 and k_sub=2 of the 4 MMA k-sub-tiles. + +### 3.3 Thread Block Configuration + +256 threads = 8 warps per thread block. + +Warp layout (for TILE_M=64, TILE_N=128): + 2 warps along M x 4 warps along N + Each warp owns a 32x32 sub-tile of C + +For the m16n8k16 MMA instruction: + Each warp's 32x32 sub-tile = 2 M-blocks x 4 N-blocks = 8 MMA positions + With TILE_K=64 (4 k-sub-tiles of 16): 8 * 4 = 32 MMA ops per warp per K-tile + +### 3.4 Register Allocation + +Per thread: +- Codebook: 1 half register (loaded at kernel start, lives for entire kernel) +- FragC accumulators: M_BLOCKS * N_BLOCKS * 2 * Vec + For M_BLOCKS=4, N_BLOCKS=4: 32 * 4 = 128 floats = 512 bytes + Per thread: 512 / 32 = 16 floats +- FragA: M_BLOCKS * Vec per k-sub-tile (double-buffered) +- FragB: Vec per N-block per k-sub-tile (not stored, consumed immediately) +- Bit-plane words: K uint32 temporaries +- Absmax: 2 half values per column group + +Total estimated: ~40-50 registers per thread. Well within the 255 limit. + +--- + +## 4. Weight Storage Format and Repacking + +### 4.1 Quantization-Time Format + +The quantize kernel (`kQuantizeBlockwise_kbit`) outputs packed data in flat +sequential order: + +``` +For a weight matrix W[K_dim, N] flattened to 1D: + Block i: elements [32*i .. 32*(i+1)) + packed[i*K + bit] = bit-plane word for bit `bit` of block i + + absmax[i] = max absolute value in block i (float32, later E4M4-encoded) +``` + +This layout is contiguous in memory but NOT optimized for GEMM tiling. +A GEMM kernel loading a TILE_K x TILE_N region would need to gather from +many non-contiguous locations. + +### 4.2 GEMM-Optimized Tiled Format + +The repack kernel transforms the flat layout into a tiled layout where each +(k_tile, n_tile) region is contiguous in memory: + +``` +B_packed[k_tile][n_tile][col_within_tile][k_block_within_tile][bit_plane] + +Dimensions: + k_tile: 0 .. K_dim/TILE_K - 1 + n_tile: 0 .. N/TILE_N - 1 + col_within_tile: 0 .. 
TILE_N - 1 (128 columns per N-tile) + k_block_within_tile: 0 .. TILE_K/32 - 1 (2 blocks per K-tile with TILE_K=64) + bit_plane: 0 .. K-1 + +Total words per tile: TILE_N * (TILE_K / 32) * K + For TILE_N=128, TILE_K=64, K=4: 128 * 2 * 4 = 1024 uint32 words = 4 KB +``` + +Absmax is stored separately in a matching tiled layout: +``` +B_absmax[k_tile][n_tile][col_within_tile][k_block_within_tile] + +Total bytes per tile: TILE_N * (TILE_K / 32) = 128 * 2 = 256 bytes (uint8) +``` + +### 4.3 Repack Kernel + +The repack kernel is a simple gather/permutation kernel, run once when the +model is loaded (not on the hot path). It maps: + +``` +Source: packed_flat[block_id * K + bit] + where block_id = (k * N + n) / 32 (for element (k, n) in row-major W) + +Destination: packed_tiled[k_tile][n_tile][col][k_block][bit] + where k_tile = k / TILE_K + n_tile = n / TILE_N + col = n % TILE_N + k_block = (k % TILE_K) / 32 + bit = 0..K-1 +``` + +Similarly for absmax: +``` +Source: absmax_flat[block_id] +Destination: absmax_tiled[k_tile][n_tile][col][k_block] +``` + +The repack kernel should also handle E4M4 encoding of absmax if it hasn't +been done already. + +### 4.4 Why Bit-Plane Format (Not Contiguous Packing) + +We keep the bit-plane format for the GEMM kernel rather than converting to +contiguous K-bit packing. Reasons: + +1. **Uniform across all K values**: K=2,3,4,5 all work identically. Contiguous + packing is awkward for K=3,5 (don't divide 32 evenly, boundary-crossing + extraction needed). + +2. **Same memory footprint**: K words per block of 32 regardless of format. + Both formats use exactly K * 4 bytes per 32 elements. + +3. **Extraction cost is hidden**: The bit-plane extraction (K shift+mask+OR + per element) runs on INT32 ALU, concurrent with tensor core MMA. The + cost is effectively free in the steady state. + +4. **No format conversion needed**: The quantize kernel already produces + bit-planes. Repacking only changes the tile layout, not the data format. 
+ +--- + +## 5. Inner Loop: Dequantization + MMA + +### 5.1 Tensor Core Fragment Layout + +For the `m16n8k16` MMA instruction (fp16 inputs, fp32 accumulation): + +The B matrix (weights) in the MMA is k=16 x n=8. Per thread t (lane 0-31): + +| Register | Row indices | Column | +|-----------|-----------------------------|---------| +| b[0] (half2) | k = 2*(t%4), 2*(t%4)+1 | n = t/4 | +| b[1] (half2) | k = 2*(t%4)+8, 2*(t%4)+9 | n = t/4 | + +Key property: all 4 elements a thread needs are in the SAME column (n = t/4). +The rows are at positions {2i, 2i+1, 2i+8, 2i+9} where i = t%4. + +This means threads 4n, 4n+1, 4n+2, 4n+3 all access the same column n. +When loading bit-plane words from shared memory, these 4 threads read the +same K addresses -> shared memory broadcast (no bank conflict). + +### 5.2 Bit-Plane Loading from Shared Memory + +In the standalone dequant kernel, bit-plane words are loaded from global +memory using the shuffle-broadcast trick (only lane `bit` loads, broadcasts +to all). This pattern DOES NOT WORK in the GEMM context because: + +1. Threads are not mapped 1:1 to elements -- they're mapped to tensor core + fragment positions. +2. Data is in shared memory (loaded by the async pipeline), not global memory. +3. Multiple threads need the same bit-plane words (4 threads per column). + +Instead, in the GEMM kernel, each thread reads K words directly from shared +memory for its column's block: + +```cpp +// my_col: which N-column this thread handles in the current MMA sub-tile +// This is determined by the tensor core fragment layout: my_col = lane_id / 4 +int my_col = (threadIdx.x % 32) / 4; // 0-7 for the 8 columns in m16n8k16 + +// Load K bit-plane words for this column's block +uint32_t planes[K_BITS]; +#pragma unroll +for (int b = 0; b < K_BITS; b++) + planes[b] = sh_b[column_offset + b]; +``` + +Since 4 threads share the same column (same `my_col` value), they all read +the same K addresses from shared memory. 
This is a 4-way broadcast, which +shared memory handles natively with no bank conflicts. + +With 8 distinct columns per warp and K=4: +- 8 groups of 4 threads, each reading from different addresses +- 8 different banks accessed simultaneously -> zero conflicts + +### 5.3 Index Extraction from Bit-Planes + +After loading the K bit-plane words into registers, each thread extracts +indices for its 4 fragment rows: + +```cpp +int row_base = 2 * (lane_id % 4); // 0, 2, 4, or 6 +int rows[4] = {row_base, row_base + 1, row_base + 8, row_base + 9}; + +half vals[4]; +#pragma unroll +for (int r = 0; r < 4; r++) { + int idx = 0; + #pragma unroll + for (int b = 0; b < K_BITS; b++) + idx |= ((planes[b] >> rows[r]) & 1) << b; + + // Codebook lookup + scale (see Section 5.4) + half cb_val = __shfl_sync(0xFFFFFFFF, cb_h, idx); + vals[r] = __hmul(cb_val, scale); +} + +// Pack into FragB +half2 frag_b[2]; +frag_b[0] = __halves2half2(vals[0], vals[1]); +frag_b[1] = __halves2half2(vals[2], vals[3]); +``` + +ALU cost per FragB (4 values, K=4): +- Index extraction: 4 elements * 4 bits = 16 shift+mask+OR ops (INT32) +- Codebook lookup: 4 __shfl_sync ops (shuffle unit) +- Scale: 4 __hmul ops (FP16 ALU) +- Pack: 2 __halves2half2 ops + +All of these run on different functional units from the tensor core MMA, +so they overlap with MMA execution. + +### 5.4 Codebook Lookup + +The codebook is stored as a half-precision value in each lane's register: + +```cpp +// At kernel start (once): +int lane = threadIdx.x % 32; +half cb_h = (lane < (1 << K_BITS)) + ? __float2half(codebook[lane]) + : __float2half(0.0f); +``` + +Lookup uses `__shfl_sync` with per-thread independent source lane: + +```cpp +half val = __shfl_sync(0xFFFFFFFF, cb_h, idx); +``` + +Each thread can request the value from any lane. The shuffle unit handles +arbitrary per-thread source selection. Cost: 1 cycle, no memory access. 
+ +Why shuffle (not constant memory or shared memory): +- Constant memory: optimized for broadcast (all threads same address). + With divergent indices (each thread wants a different codebook entry), + it serializes -- up to 2^K sequential reads. Bad. +- Shared memory: works (no bank conflicts for K<=4 since entries fit in + distinct banks), but adds shared memory traffic. +- Shuffle: 1 cycle, zero memory, perfect for this use case. Already + proven in the existing dequant kernel. + +### 5.5 Complete Dequant + MMA Sequence + +For one K-tile (TILE_K=64, 4 sub-tiles of k=16): + +```cpp +for (int k_sub = 0; k_sub < 4; k_sub++) { + // Which kbit block does this sub-tile fall in? + // k_sub 0,1 -> block 0 (first 32 elements), k_sub 2,3 -> block 1 + half scale = (k_sub < 2) ? absmax_h[0] : absmax_h[1]; + + // Load A fragments via ldmatrix (from shared memory) + FragA frag_a[M_BLOCKS]; + for (int m = 0; m < M_BLOCKS; m++) + ldmatrix_a(frag_a[m], sh_a, m, k_sub); + + // For each N-block in this warp's sub-tile: + for (int n = 0; n < N_BLOCKS; n++) { + // Load bit-plane words from shared memory + uint32_t planes[K_BITS]; + load_b_planes(planes, sh_b, n, k_sub); + + // Dequant: extract indices, codebook lookup, scale + half2 frag_b[2]; + dequant_kbit_fragb(planes, scale, cb_h, frag_b); + + // MMA: accumulate across all M-blocks (A fragments reused) + for (int m = 0; m < M_BLOCKS; m++) { + mma_m16n8k16(frag_a[m], frag_b, frag_c[m][n]); + } + } +} +``` + +The key data reuse pattern: +- A fragments: loaded once per M-block, reused across all N-blocks +- B fragments: dequantized once per N-block, reused across all M-blocks +- Codebook register: loaded once at kernel start, reused forever +- Absmax: decoded once per block-of-32 per column, reused across M-blocks + +--- + +## 6. 
Persistent Kernel and Work Distribution + +### 6.1 Why Persistent Kernel + +For typical LLM shapes (N=4096-16384, M variable, K=4096-16384), the number +of M-tiles * N-tiles is often less than the number of SMs: + +| M | N | M/64 x N/128 | H100 SMs | Utilization | +|-----|------|--------------|----------|-------------| +| 128 | 4096 | 2 x 32 = 64 | 132 | 48% | +| 256 | 4096 | 4 x 32 = 128| 132 | 97% | +| 128 | 8192 | 2 x 64 = 128| 132 | 97% | + +When utilization is below ~80%, we need split-K (multiple blocks share the +same output tile, each handling a portion of K). The persistent kernel handles +this naturally. + +### 6.2 Design + +Launch exactly `num_SMs` blocks. Each block loops over assigned work items. +Work items are linearized as (m_tile, n_tile, k_chunk) triples: + +``` +Total work = m_tiles * n_tiles * k_chunks + where k_chunks = ceil(K_dim / TILE_K / tiles_per_chunk) + and tiles_per_chunk >= 8 (minimum for pipeline efficiency) + +Work items are ordered so that all k_chunks for a given (m_tile, n_tile) +are contiguous in the linearized sequence. +``` + +Each block gets a contiguous range of work items: +```cpp +int total_work = m_tiles * n_tiles * k_chunks; +int work_per_block = div_ceil(total_work, gridDim.x); +int my_start = blockIdx.x * work_per_block; +int my_end = min(my_start + work_per_block, total_work); +``` + +### 6.3 Accumulator Management + +When consecutive work items for a block share the same output tile +(same m_tile, n_tile), the accumulators persist across k_chunks. +The block accumulates without writing to memory. 
+ +When the output tile changes (or at the end), the block writes results: + +```cpp +int prev_mn = -1; +FragC frag_c[M_BLOCKS][N_BLOCKS][2]; + +for (int work_id = my_start; work_id < my_end; work_id++) { + int mn_id = work_id / k_chunks; + int k_chunk_id = work_id % k_chunks; + + if (mn_id != prev_mn) { + if (prev_mn >= 0) + write_output(frag_c, prev_mn, ...); + zero_accumulators(frag_c); + prev_mn = mn_id; + } + + // Process K-tiles for this chunk + process_k_range(k_chunk_id, frag_c, ...); +} + +// Write final tile +if (prev_mn >= 0) + write_output(frag_c, prev_mn, ...); +``` + +### 6.4 Output Write Strategy + +Three cases based on whether the block owns the full K-range for its output tile: + +```cpp +bool i_own_k_start = (my_first_k_chunk == 0); +bool i_own_k_end = (my_last_k_chunk == k_chunks - 1); + +if (i_own_k_start && i_own_k_end) { + // Full ownership: write fp16 directly to C + write_frag_fp16(frag_c, C, ...); +} +else if (i_own_k_start) { + // First contributor: overwrite fp32 workspace (acts as zero + write) + write_frag_fp32(frag_c, C_workspace, ...); +} +else { + // Subsequent contributor: atomicAdd fp32 + atomic_add_frag_fp32(frag_c, C_workspace, ...); +} +``` + +No separate memset is needed: the first contributor overwrites the workspace. + +### 6.5 Final Reduction + +When multiple blocks share an output tile, the last block to finish converts +fp32 workspace to fp16 output. This is detected via an atomic counter: + +```cpp +// Per-tile done counter (in the workspace/locks array) +if (not_full_ownership) { + int count = atomicAdd(&tile_done_count[mn_id], 1); + if (count == num_contributors - 1) { + // I'm the last one: convert fp32 -> fp16 + convert_tile_fp32_to_fp16(C_workspace, C, mn_id, ...); + } +} +``` + +The tile_done_count array is tiny: m_tiles * n_tiles ints. + +### 6.6 Pipeline Restart at Tile Boundaries + +When a block switches to a new (m_tile, n_tile) or a new k_chunk, the +pipeline must restart (new data in shared memory). 
This costs ~2 K-tiles +of pipeline fill time. Within a block's k_chunk, K-tiles are processed +sequentially with continuous pipeline operation. + +This is the main performance overhead of split-K: each split incurs a +pipeline restart. With >= 8 K-tiles per chunk, the overhead is <= 25%. +Typical values (16-32 K-tiles per chunk) give 6-12% overhead. + +### 6.7 Split-K=1 Fast Path + +When m_tiles * n_tiles >= num_SMs, no split-K is needed. Each block owns +complete output tiles and writes fp16 directly. No fp32 workspace, no +atomics, no reduction. This is the common case for large M. + +--- + +## 7. Pipeline and Shared Memory + +### 7.1 Shared Memory Layout + +``` +Per pipeline stage: ++-------------------------------------------+ +| A tile: TILE_M * TILE_K * 2 bytes (fp16) | +| For TILE_M=64, TILE_K=64: 8 KB | ++-------------------------------------------+ +| B tile (packed bit-planes): | +| TILE_N * (TILE_K/32) * K * 4 bytes | +| For TILE_N=128, K=4: 4 KB | ++-------------------------------------------+ +| Absmax (E4M4): | +| TILE_N * (TILE_K/32) * 1 byte | +| = 256 bytes | ++-------------------------------------------+ + +Total per stage (TILE_M=64, K=4): ~12.3 KB +With 2 stages (double buffer): ~24.6 KB +With 4 stages: ~49.2 KB + +GPU shared memory limits: + A100: 164 KB per SM + H100: 228 KB per SM + 4090: 100 KB per SM + +Even with 4 stages, we have ample room. 
+``` + +The compressed B tiles are 2-8x smaller than fp16 would be, which means: +- More pipeline stages fit in shared memory (better latency hiding) +- Or larger tiles fit (better compute efficiency) + +### 7.2 Pipeline Structure + +Double-buffered pipeline with cp.async: + +```cpp +// Initial fill +fetch_tile_to_shared(/*stage=*/0, k_tile_start); +fetch_tile_to_shared(/*stage=*/1, k_tile_start + 1); +cp_async_fence(); + +for (int kt = k_tile_start; kt < k_tile_end; kt++) { + int stage = (kt - k_tile_start) % 2; + + cp_async_wait<1>(); // wait for current stage + __syncthreads(); + + // Prefetch next tile + if (kt + 2 < k_tile_end) { + fetch_tile_to_shared((kt + 2) % 2, kt + 2); + } + cp_async_fence(); + + // Process: dequant + MMA for current tile + process_k_tile(stage, frag_c, cb_h); +} + +cp_async_wait<0>(); +__syncthreads(); +``` + +### 7.3 Fetch Functions + +```cpp +__device__ void fetch_tile_to_shared(int stage, int k_tile) { + int4* sh_a = sh_a_base + stage * a_stage_words; + uint32_t* sh_b = sh_b_base + stage * b_stage_words; + uint8_t* sh_abs = sh_abs_base + stage * abs_stage_bytes; + + // Load A tile: TILE_M * TILE_K / 8 int4 loads + // 256 threads, each loads ceil(A_size / 256) int4 words + for (int i = threadIdx.x; i < a_tile_int4s; i += blockDim.x) { + cp_async4(&sh_a[i], &A_global[a_offset + i]); + } + + // Load B tile (packed): much smaller than A + for (int i = threadIdx.x; i < b_tile_int4s; i += blockDim.x) { + if (i < actual_b_words) + cp_async4(&sh_b_int4[i], &B_global[b_offset + i]); + } + + // Load absmax: very small (256 bytes) + if (threadIdx.x < abs_tile_int4s) { + cp_async4(&sh_abs_int4[threadIdx.x], &absmax_global[abs_offset + threadIdx.x]); + } +} +``` + +Note the asymmetry: A loading dominates bandwidth, B loading is "free" +relative to A. This is a key advantage of compressed weights. 
+ +### 7.4 Bank Conflict Analysis + +**A tile reads (via ldmatrix):** Standard ldmatrix access pattern, +well-studied, no conflicts with standard swizzled layout. + +**B tile reads (bit-plane words):** As analyzed in Section 5.2, 4 threads +per column group read the same addresses (broadcast), 8 column groups read +different addresses (different banks). Zero conflicts. + +**Absmax reads:** Each thread reads one uint8 for its column. With 8 columns +per warp, these are at different byte addresses. No conflicts. + +--- + +## 8. Codebook and Absmax Handling + +### 8.1 Codebook Precision + +The existing dequant kernel uses float32 codebook values. For the GEMM kernel, +we convert to half at kernel start: + +```cpp +half cb_h = (lane < (1 << K_BITS)) + ? __float2half(codebook[lane]) + : __float2half(0.0f); +``` + +Rationale: +- Codebook values are in [-1, 1], well within half precision +- The MMA instruction takes fp16 inputs anyway +- Avoids float->half conversion in the inner loop (4 conversions per FragB) +- MMA accumulates in fp32, so precision loss in fp16 fragments is minimal +- The quantization error itself (~6% for K=4) dominates any fp16 rounding + +### 8.2 Absmax Decode and Application + +With TILE_K=64, each K-tile spans exactly 2 kbit blocks. Each column has +exactly 2 absmax values per K-tile. This is much simpler than Marlin's +group boundary logic because there's no straddling -- the boundaries are +always at fixed positions. + +```cpp +// Load 2 absmax values from shared memory for this column +uint8_t raw0 = sh_absmax[my_col * 2 + 0]; // block 0 (k=0..31) +uint8_t raw1 = sh_absmax[my_col * 2 + 1]; // block 1 (k=32..63) + +// Decode E4M4 -> half (done once per column per K-tile) +half scale0 = __float2half(decode_e4m4_absmax(raw0)); +half scale1 = __float2half(decode_e4m4_absmax(raw1)); + +// In the sub-tile loop: +for (int k_sub = 0; k_sub < 4; k_sub++) { + half scale = (k_sub < 2) ? scale0 : scale1; + // ... dequant uses __hmul(codebook_val, scale) ... 
+} +``` + +The decode is ~5 integer ALU ops, done twice per column per K-tile, +shared across all M-rows. Negligible cost. + +### 8.3 Absmax as Group Scale + +The per-block absmax is functionally identical to Marlin's group scale +mechanism. In Marlin terminology: +- group_size = 32 (our blocksize) +- group_blocks = TILE_K / 32 = 2 (number of groups per K-tile) + +But our implementation is much simpler because: +1. No activation reordering (act-order) to worry about +2. Group boundaries always align with K-tile boundaries +3. No zero-point subtraction +4. Scale format is fixed (E4M4 uint8) + +--- + +## 9. Performance Analysis + +### 9.1 Arithmetic Intensity + +Per thread block per K-tile: +- Compute: 8 warps * 32 MMA ops * 256 FMA ops = 65,536 FMAs = 131,072 FLOPs + (with TILE_K=64, this doubles to 262,144 FLOPs) +- Memory: + - A: TILE_M * TILE_K * 2 bytes = 64 * 64 * 2 = 8,192 bytes + - B: TILE_N * (TILE_K/32) * K * 4 = 128 * 2 * 4 * 4 = 4,096 bytes (K=4) + - Absmax: TILE_N * (TILE_K/32) = 128 * 2 = 256 bytes + - Total: 12,544 bytes + +Arithmetic intensity: 262,144 / 12,544 = 20.9 FLOP/byte + +Compare fp16 GEMM (same tiles, B in fp16): +- B would be: 128 * 64 * 2 = 16,384 bytes +- Total: 24,832 bytes +- Intensity: 262,144 / 24,832 = 10.6 FLOP/byte + +The kbit kernel has ~2x higher arithmetic intensity for the same tile size. + +### 9.2 Compute-Bound Threshold + +On H100 (990 TFLOPS fp16 tensor, 3.35 TB/s HBM): +Compute-bound threshold: 990e12 / 3.35e12 = 295 FLOP/byte + +For C[M, 4096] = A[M, 4096] * W[4096, 4096] with K=4: +- FLOPs: 2 * M * 4096 * 4096 +- Bytes: M * 4096 * 2 (A) + 4096 * 4096 * 0.53 (B, K=4 + E4M4) + M * 4096 * 2 (C) + +Solving for compute-bound threshold: +- M=1: intensity ~3, memory-bound +- M=32: intensity ~93, memory-bound +- M=128: intensity ~296, at the boundary +- M=256: intensity ~465, compute-bound + +For M >= ~128 on H100, we're compute-bound and tensor core utilization +determines performance. 
+ +### 9.3 Expected Performance vs Marlin Stripes + +The persistent kernel with explicit work distribution loses ~5-15% vs +Marlin-style stripes in unfavorable cases. The overhead comes from: + +1. Pipeline startup/drain: 2 K-tiles overhead per k_chunk. + With >= 8 tiles per chunk: <= 25% overhead on the chunked portion. + Typical: 6-12%. + +2. Tail-wave imbalance: last wave of blocks may not fill all SMs. + Typically 0-5%. + +3. AtomicAdd reduction: < 1% (negligible on Ampere+). + +For K=4096 with split_k effective=2-4: expect ~10% overhead. +For K=8192+ or when no split-K needed: ~0-3% overhead. +This is acceptable given the massive implementation simplicity gain. + +### 9.4 Effective Bits Per Weight Element + +``` +K=2: 2/8 + 1/32 = 0.28125 bytes/element (7.1x compression vs fp16) +K=3: 3/8 + 1/32 = 0.40625 bytes/element (4.9x compression) +K=4: 4/8 + 1/32 = 0.53125 bytes/element (3.8x compression) +K=5: 5/8 + 1/32 = 0.65625 bytes/element (3.0x compression) + +(The 1/32 term is the E4M4 absmax overhead: 1 byte per 32 elements) +``` + +--- + +## 10. 
Kernel Dispatch and Python Integration

### 10.1 Host-Side Dispatch

```cpp
void kbit_gemm(
    const half* A,            // [M, K_dim] row-major
    const uint32_t* B,        // tiled kbit packed data
    half* C,                  // [M, N] row-major
    float* C_workspace,       // [M, N] fp32 workspace (for split-K)
    int* tile_counters,       // [m_tiles * n_tiles] atomic counters
    const uint8_t* absmax,    // tiled E4M4 absmax
    const float* codebook,    // [2^K] float32 codebook
    int M, int N, int K_dim, int K_bits,
    cudaStream_t stream)
{
    int dev;
    cudaGetDevice(&dev);
    int sms;
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
    int max_shmem;
    // Note: the opt-in per-block limit (enumerator is ...PerBlockOptin),
    // i.e. the maximum usable after cudaFuncSetAttribute below.
    cudaDeviceGetAttribute(&max_shmem,
        cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

    // Choose M-blocking
    int m_blocks;
    if (M <= 16) m_blocks = 1;
    else if (M <= 32) m_blocks = 2;
    else if (M <= 48) m_blocks = 3;
    else m_blocks = 4;
    int tile_m = m_blocks * 16;

    // Choose tile config
    struct Config { int tile_k, tile_n, threads; };
    Config cfg = select_config(m_blocks, M, N, K_dim, K_bits, max_shmem);

    // Compute work distribution
    int m_tiles = div_ceil(M, tile_m);
    int n_tiles = N / cfg.tile_n;
    int k_tiles = K_dim / cfg.tile_k;
    int min_tiles_per_chunk = 8;
    int k_chunks = max(1, div_ceil(k_tiles, max(min_tiles_per_chunk,
        div_ceil(k_tiles * m_tiles * n_tiles, sms) /* target full occupancy */)));

    // Zero tile counters if split-K (k_chunks > 1 means several blocks
    // may contribute to the same output tile)
    bool needs_split_k = (k_chunks > 1);
    if (needs_split_k) {
        cudaMemsetAsync(tile_counters, 0, m_tiles * n_tiles * sizeof(int), stream);
    }

    // Launch persistent kernel
    int shmem_size = compute_shmem(cfg, m_blocks, K_bits);
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);

    // Dispatch on K_bits and m_blocks
    dispatch_kernel(K_bits, m_blocks, cfg, sms, shmem_size, stream, ...);
}
```

### 10.2 Config Selection

Priority-ordered configs for small and large batch:

```cpp
// Small batch 
(m_blocks == 1): +Config small_configs[] = { + {64, 128, 256}, // balanced + {64, 128, 128}, // fewer threads, less shmem + {32, 128, 128}, // shallow K, tight shmem +}; + +// Large batch (m_blocks > 1): +Config large_configs[] = { + {64, 256, 256}, // wide N, maximum output parallelism + {64, 128, 256}, // balanced + {64, 128, 128}, // fallback +}; +``` + +Validation: config must fit in shared memory and divide problem dimensions. + +### 10.3 Python Binding + +Following the existing pattern in `bitsandbytes/_ops.py`: + +```python +torch.library.define( + "bitsandbytes::kbit_gemm", + "(Tensor A, Tensor B_packed, Tensor absmax, Tensor codebook, " + "int k, int N, int K_dim) -> Tensor", +) +``` + +CUDA backend in `bitsandbytes/backends/cuda/ops.py`: +```python +@register_kernel("bitsandbytes::kbit_gemm", "cuda") +def _(A, B_packed, absmax, codebook, k, N, K_dim): + M = A.shape[0] + C = torch.empty(M, N, dtype=A.dtype, device=A.device) + # ... allocate workspace, call C function ... + return C +``` + +### 10.4 Repack API + +```python +torch.library.define( + "bitsandbytes::kbit_repack_for_gemm", + "(Tensor packed_flat, Tensor absmax_flat, int K_dim, int N, int k, " + "int tile_k, int tile_n) -> (Tensor, Tensor)", +) +``` + +This would be called once when loading a model, before inference begins. + +--- + +## 11. File Organization and Build + +### 11.1 Kernel Location + +The GEMM kernel should go in `csrc/kernels.cu` (the standard location for +CUDA kernels in bitsandbytes), NOT in `csrc/ops.cu`. + +Background: The existing kbit quantize/dequantize kernels were placed in +`ops.cu` to avoid RDC (relocatable device code) linking issues with template +instantiations. This was a workaround, not a deliberate architectural choice. +The `CUDA_RESOLVE_DEVICE_SYMBOLS ON` flag was added to CMakeLists.txt as +part of that workaround and should be removed. 
+ +For the GEMM kernel: place the kernel definition and launch wrapper in +`csrc/kernels.cu` with declarations in `csrc/kernels.cuh`. The extern "C" +wrappers go in `csrc/pythonInterface.cpp` following the existing pattern. + +### 11.2 CMakeLists.txt + +Remove the `CUDA_RESOLVE_DEVICE_SYMBOLS ON` flag that was added as a +workaround. The GEMM kernel doesn't need it if templates are properly +instantiated in the same compilation unit as their declarations. + +### 11.3 New Files + +No new .cu files needed. The GEMM kernel fits naturally in the existing +file structure: +- Kernel code: `csrc/kernels.cu` (append) +- Kernel declarations: `csrc/kernels.cuh` (append) +- Launch wrappers: `csrc/ops.cu` (append, for the host-side dispatch) +- C interface: `csrc/pythonInterface.cpp` (append) +- Python ops: `bitsandbytes/_ops.py` (append) +- CUDA backend: `bitsandbytes/backends/cuda/ops.py` (append) +- Tests: `tests/test_kbit_gemm.py` (new) + +### 11.4 Template Instantiation Strategy + +The GEMM kernel is templated on: +- K_BITS: 2, 3, 4, 5 +- M_BLOCKS: 1, 2, 3, 4 +- Tile config (TILE_K, TILE_N): 2-3 configs + +Total: 4 * 4 * 3 = 48 kernel variants (worst case). +This is manageable. Marlin has hundreds of variants. + +Instantiation via macros, similar to existing pattern: +```cpp +#define INSTANTIATE_KBIT_GEMM(K, M_BLOCKS, TILE_K, TILE_N) \ + template __global__ void kbit_gemm_kernel(...); + +INSTANTIATE_KBIT_GEMM(2, 1, 64, 128) +INSTANTIATE_KBIT_GEMM(2, 2, 64, 128) +// ... etc +``` + +--- + +## 12. 
Error Budget + +### 12.1 Error Sources + +The existing test suite establishes the combined error bound per block: + +``` +max_error <= (max_gap/2 + 1/16) * absmax + epsilon + +where: + max_gap: maximum gap between adjacent codebook entries + 1/16: maximum relative error from E4M4 absmax encoding + absmax: absolute maximum of the block + epsilon: small constant for floating-point rounding (~1e-6) +``` + +The GEMM kernel introduces no new error sources beyond the standalone dequant: +- Same bit-plane extraction (exact) +- Same codebook lookup (exact, via shuffle) +- Same absmax multiply (same precision) +- fp16 codebook storage adds at most 1 ULP of fp16 (~0.001 for values near 1.0) +- MMA accumulates in fp32 (no precision loss in accumulation) + +### 12.2 SQNR Expectations + +From the test suite (1M elements, normal distribution): +- K=2: SQNR > 5 dB +- K=3: SQNR > 10 dB +- K=4: SQNR > 15 dB +- K=5: SQNR > 20 dB + +E4M4 absmax degrades SQNR by < 1.5 dB vs fp32 absmax. + +The GEMM kernel should match these bounds exactly, since the dequant +logic is identical. + +--- + +## 13. Template Instantiations + +### 13.1 Kernel Template + +```cpp +template +__global__ void kbit_gemm_kernel( + const half* __restrict__ A, + const uint32_t* __restrict__ B_packed, + half* __restrict__ C, + float* __restrict__ C_workspace, + int* __restrict__ tile_counters, + const uint8_t* __restrict__ B_absmax, + const float* __restrict__ codebook, + int M, int N, int K_dim, + int m_tiles, int n_tiles, int k_chunks, + int tiles_per_chunk); +``` + +### 13.2 Repack Kernel Template + +```cpp +template +__global__ void kbit_repack_kernel( + const uint32_t* __restrict__ packed_flat, + const uint8_t* __restrict__ absmax_flat, + uint32_t* __restrict__ packed_tiled, + uint8_t* __restrict__ absmax_tiled, + int K_dim, int N); +``` + +--- + +## 14. 
Future Considerations + +### 14.1 Hopper (sm_90) Optimizations + +On Hopper GPUs, warp specialization can be used: producer warps handle +data loading (using TMA for efficient async copies), consumer warps handle +compute. The producer warps could handle the bit-plane loading and even +partial dequantization, feeding pre-dequantized fp16 tiles to consumer +warps. This would further overlap memory and compute. + +### 14.2 Larger Block Sizes + +The current kbit implementation uses blocksize=32 (warp-size). Larger +block sizes (64, 128) would reduce the absmax overhead but require +different packing primitives (can't use single-warp __ballot_sync for +blocks > 32). This would be a separate project. + +### 14.3 Activation Quantization (W_kbit * A_kbit) + +If activations are also kbit-quantized, the GEMM becomes a fully quantized +matmul. This would require a different kernel architecture (integer MMA +or custom accumulation). + +### 14.4 Fused Operations + +Common fused patterns for inference: +- kbit GEMM + bias add +- kbit GEMM + ReLU/GELU +- kbit GEMM + residual add + +These can be added as epilogue options in the kernel template, similar to +Marlin's bias support. + +### 14.5 Batched GEMM + +For attention computation, batched GEMM (multiple independent GEMMs) may +be needed. The persistent kernel can be extended to handle batches by adding +a batch dimension to the work assignment. 
+ +--- + +## Appendix A: Marlin Code References + +Key files in `~/git/vllm/csrc/quantization/marlin/`: +- `marlin_template.h`: Main kernel template (~2070 lines) + - Line 271-281: Stripe partitioning explanation + - Line 362-401: Work distribution setup + - Line 916-923: Pipeline wait/fence + - Line 927-939: Register fetch from shared memory + - Line 1167-1285: matmul() inner loop with dequant + scale + MMA + - Line 1780-1813: Main K-loop with pipeline interleaving + - Line 1839-2068: Output reduction and slice management +- `marlin.cu`: Host dispatch (~530 lines) + - Line 128-141: Thread config tables + - Line 179-249: Config validation + - Line 265-313: Config selection + - Line 315-527: Main dispatch function +- `marlin_mma.h`: MMA instruction wrappers +- `dequant.h`: Dequantization functions (lop3-based) +- `marlin.cuh`: Constants and helpers + +## Appendix B: Glossary + +- **Block (quantization)**: A group of 32 consecutive elements sharing one absmax value +- **Block (CUDA)**: A CUDA thread block (256 threads = 8 warps) +- **Bit-plane**: A uint32 word containing one bit from each of 32 elements +- **FragA, FragB, FragC**: Register fragments for tensor core MMA +- **MMA**: Matrix multiply-accumulate (tensor core instruction) +- **m16n8k16**: MMA instruction computing a 16x8 output from 16x16 and 16x8 inputs +- **Split-K**: Partitioning the K (reduction) dimension across multiple thread blocks +- **Tile**: A sub-matrix processed by one thread block or one MMA instruction +- **TILE_K, TILE_M, TILE_N**: Thread block tile dimensions +- **Persistent kernel**: A kernel that launches exactly num_SMs blocks, each looping over work +- **E4M4**: 8-bit float format with 4-bit exponent and 4-bit mantissa +- **Codebook**: A lookup table of 2^K reconstruction values for quantization +- **absmax**: Per-block absolute maximum, used as scale factor +- **Normal-float**: Quantization levels placed at quantiles of N(0,1) diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py 
index 532fe7afa..435171d54 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -431,3 +431,74 @@ def _( qmap2.dtype == absmax2.dtype == torch.float32, lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", ) + + +# K-bit blockwise quantization (K=2..5, blocksize=32) + +torch.library.define( + "bitsandbytes::quantize_kbit", + "(Tensor A, Tensor codebook, int k) -> (Tensor, Tensor)", +) + + +@register_fake("bitsandbytes::quantize_kbit") +def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}") + torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}") + n = A.numel() + num_blocks = -(n // -32) + # packed: num_blocks * k int32 words + k padding words + packed = torch.empty(num_blocks * k + k, device=A.device, dtype=torch.int32) + absmax = torch.empty(num_blocks + 1, device=A.device, dtype=torch.float32) + return packed, absmax + + +torch.library.define( + "bitsandbytes::dequantize_kbit", + "(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype) -> Tensor", +) + + +@register_fake("bitsandbytes::dequantize_kbit") +def _( + packed: torch.Tensor, + codebook: torch.Tensor, + absmax: torch.Tensor, + k: int, + n: int, + dtype: torch.dtype, +) -> torch.Tensor: + torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}") + torch._check( + absmax.dtype in (torch.float32, torch.uint8), + lambda: f"absmax must be float32 or uint8 (E4M4), got {absmax.dtype}", + ) + num_blocks = -(n // -32) + return torch.empty(num_blocks * 32, device=packed.device, dtype=dtype) + + +torch.library.define( + "bitsandbytes::dequantize_kbit_", + "(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype, Tensor(a!) 
out) -> Tensor(a!)", +) + + +@register_fake("bitsandbytes::dequantize_kbit_") +def _( + packed: torch.Tensor, + codebook: torch.Tensor, + absmax: torch.Tensor, + k: int, + n: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> torch.Tensor: + torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}") + torch._check( + absmax.dtype in (torch.float32, torch.uint8), + lambda: f"absmax must be float32 or uint8 (E4M4), got {absmax.dtype}", + ) + num_blocks = -(n // -32) + torch._check(out.numel() >= num_blocks * 32, lambda: f"out must have at least {num_blocks * 32} elements") + torch._check(out.dtype == dtype, lambda: f"out dtype {out.dtype} must match requested dtype {dtype}") + return out diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index d92f9a490..f81a270e3 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -764,3 +764,117 @@ def _optimizer_update_8bit_blockwise_impl( register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl) register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl) + + +# K-bit blockwise quantization (K=2..5, blocksize=32) + +_KBIT_DTYPE_SUFFIX = { + torch.float16: "fp16", + torch.bfloat16: "bf16", + torch.float32: "fp32", +} + + +@register_kernel("bitsandbytes::quantize_kbit", "cuda") +def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}") + torch._check( + A.dtype in _KBIT_DTYPE_SUFFIX, + lambda: f"quantize_kbit only supports float16/bfloat16/float32, got {A.dtype}", + ) + torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}") + torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}") + + n = A.numel() + num_blocks = -(n // -32) + packed = torch.zeros(num_blocks * k + k, 
device=A.device, dtype=torch.int32) + absmax = torch.zeros(num_blocks + 1, device=A.device, dtype=torch.float32) + + with _cuda_device_of(A): + tname = _KBIT_DTYPE_SUFFIX[A.dtype] + fn = getattr(lib, f"cquantize_kbit_{tname}_k{k}") + fn( + get_ptr(codebook), + get_ptr(A), + get_ptr(absmax), + get_ptr(packed), + ct.c_int(n), + ) + + return packed, absmax + + +_KBIT_ABSMAX_SUFFIX = { + torch.uint8: "u8abs", + torch.float16: "fp16abs", +} + + +def _dequantize_kbit_impl( + packed: torch.Tensor, + codebook: torch.Tensor, + absmax: torch.Tensor, + k: int, + n: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}") + torch._check( + dtype in _KBIT_DTYPE_SUFFIX, + lambda: f"dequantize_kbit only supports float16/bfloat16/float32, got {dtype}", + ) + torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}") + torch._check( + absmax.dtype in (torch.float32, torch.float16, torch.uint8), + lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}", + ) + + # If fp32 absmax, encode to E4M4 first + if absmax.dtype == torch.float32: + from bitsandbytes.functional import encode_absmax_e4m4 + + absmax = encode_absmax_e4m4(absmax) + + tname = _KBIT_DTYPE_SUFFIX[dtype] + aname = _KBIT_ABSMAX_SUFFIX[absmax.dtype] + + with _cuda_device_of(packed): + fn = getattr(lib, f"cdequantize_kbit_{tname}_{aname}_k{k}") + fn( + get_ptr(packed), + get_ptr(codebook), + get_ptr(absmax), + get_ptr(out), + ct.c_int(n), + _get_tensor_stream(packed), + ) + + +@register_kernel("bitsandbytes::dequantize_kbit", "cuda") +def _( + packed: torch.Tensor, + codebook: torch.Tensor, + absmax: torch.Tensor, + k: int, + n: int, + dtype: torch.dtype, +) -> torch.Tensor: + num_blocks = -(n // -32) + out = torch.empty(num_blocks * 32, device=packed.device, dtype=dtype) + _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out) + return out + + 
+@register_kernel("bitsandbytes::dequantize_kbit_", "cuda") +def _( + packed: torch.Tensor, + codebook: torch.Tensor, + absmax: torch.Tensor, + k: int, + n: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> torch.Tensor: + _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out) + return out diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index bca3dd66d..b3de9d1c0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1005,6 +1005,213 @@ def dequantize_4bit( return out +# --------------------------------------------------------------------------- +# K-bit blockwise quantization (K=2..5, blocksize=32) +# --------------------------------------------------------------------------- + +# Cache for precomputed normal-float codebooks (K -> Tensor on each device) +_kbit_codebook_cache: dict[tuple[int, torch.device], torch.Tensor] = {} + + +def create_normal_float_codebook(k: int, device=None) -> torch.Tensor: + """Create a 2^k-entry normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]). + + For k bits we have 2^k reconstruction levels placed at the expected values + of N(0,1) within 2^k equiprobable bins. The result is sorted ascending + and normalized so the largest magnitude is 1.0. + + Args: + k: Bit width (2-5). + device: Target device. Defaults to "cuda". + + Returns: + Float32 tensor of shape (2^k,) with values in [-1, 1]. + """ + try: + from scipy.stats import norm + except ImportError as ie: + raise ImportError( + "Scipy is required for `create_normal_float_codebook`. 
Install `bitsandbytes` with the `[test]` extra.", + ) from ie + + if device is None: + device = torch.device("cuda") + device = torch.device(device) + + cache_key = (k, device) + if cache_key in _kbit_codebook_cache: + return _kbit_codebook_cache[cache_key] + + n_levels = 1 << k + quantiles = torch.linspace(0.5 / n_levels, 1.0 - 0.5 / n_levels, n_levels) + values = torch.tensor(norm.ppf(quantiles.numpy()), dtype=torch.float32) + values = values / values.abs().max() + values = values.to(device) + + _kbit_codebook_cache[cache_key] = values + return values + + +def encode_absmax_e4m4(absmax: Tensor, bias: int = 11) -> Tensor: + """Encode fp32 absmax values to uint8 using E4M4 micro-float format. + + Format: 4-bit exponent + 4-bit mantissa with IEEE-style subnormals. + Normal (e > 0): 2^(e - bias) * (1 + m/16) + Subnormal (e = 0): 2^(1 - bias) * (m/16) + Zero (e = 0, m = 0): 0.0 + + Args: + absmax: float32 tensor of per-block absolute maximum values. + bias: Exponent bias. Default 11 gives range [6.1e-5, 31.0]. + + Returns: + uint8 tensor of same shape as absmax. 
+ """ + result = torch.zeros_like(absmax, dtype=torch.uint8) + nonzero = absmax > 0 + + # Compute exponent: floor(log2(absmax)) + log2_val = torch.log2(absmax[nonzero]) + e_unbiased = torch.floor(log2_val).to(torch.int32) + + # Clamp to representable range + e_biased = (e_unbiased + bias).clamp(0, 15) + + # Handle subnormals (e_biased <= 0 before clamping) + is_subnormal = (e_unbiased + bias) <= 0 + e_biased[is_subnormal] = 0 + + # Compute mantissa + abs_nz = absmax[nonzero] + # Normal: m = round((absmax / 2^e_unbiased - 1) * 16) + # Subnormal: m = round(absmax / 2^(1-bias) * 16) + mantissa = torch.zeros_like(abs_nz, dtype=torch.int32) + + normal_mask = ~is_subnormal + if normal_mask.any(): + e_ub_normal = e_unbiased[normal_mask] + scale = torch.exp2(e_ub_normal.float()) + m_float = (abs_nz[normal_mask] / scale - 1.0) * 16.0 + mantissa[normal_mask] = m_float.round().to(torch.int32).clamp(0, 15) + + if is_subnormal.any(): + subnormal_scale = 2.0 ** (1 - bias) + m_float = abs_nz[is_subnormal] / subnormal_scale * 16.0 + mantissa[is_subnormal] = m_float.round().to(torch.int32).clamp(0, 15) + + encoded = (e_biased << 4 | mantissa).to(torch.uint8) + result[nonzero] = encoded + return result + + +def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor: + """Decode uint8 E4M4 absmax values to fp32. + + Args: + encoded: uint8 tensor of E4M4-encoded absmax values. + bias: Exponent bias (must match encoding). + + Returns: + float32 tensor of decoded absmax values. 
+ """ + raw = encoded.to(torch.int32) + e = raw >> 4 + m = raw & 0xF + + # Normal: 2^(e - bias) * (1 + m/16) + # Subnormal: 2^(1 - bias) * (m/16) + is_subnormal = e == 0 + result = torch.zeros_like(encoded, dtype=torch.float32) + + if (~is_subnormal).any(): + e_normal = e[~is_subnormal].float() + m_normal = m[~is_subnormal].float() + result[~is_subnormal] = torch.exp2(e_normal - bias) * (1.0 + m_normal / 16.0) + + if is_subnormal.any(): + m_sub = m[is_subnormal].float() + result[is_subnormal] = (2.0 ** (1 - bias)) * (m_sub / 16.0) + + return result + + +def quantize_kbit( + A: Tensor, + k: int = 4, + codebook: Optional[Tensor] = None, + absmax_format: str = "e4m4", +) -> tuple[Tensor, Tensor, Tensor]: + """Quantize a tensor using k-bit blockwise quantization (blocksize=32). + + Uses warp-level CUDA primitives for efficient bit-plane packing. + + Args: + A: Input tensor. Supports float16, bfloat16, or float32. + k: Bit width (2, 3, 4, or 5). Defaults to 4. + codebook: Optional float32 codebook tensor with 2^k entries in [-1, 1], sorted ascending. + If None, uses a precomputed normal-float codebook. + absmax_format: Format for absmax storage. "e4m4" (default, uint8) or "fp32". + + Returns: + Tuple of (packed, absmax, codebook): + - packed: int32 tensor of bit-plane packed quantized values. + - absmax: Tensor of per-block absolute maximum values (float32 or uint8). + - codebook: The codebook tensor used (useful when auto-generated). 
+ """ + if codebook is None: + codebook = create_normal_float_codebook(k, device=A.device) + else: + codebook = codebook.to(device=A.device, dtype=torch.float32) + + A_flat = A.contiguous().view(-1) + packed, absmax = torch.ops.bitsandbytes.quantize_kbit(A_flat, codebook, k) + + if absmax_format == "e4m4": + absmax = encode_absmax_e4m4(absmax) + + return packed, absmax, codebook + + +def dequantize_kbit( + packed: Tensor, + absmax: Tensor, + codebook: Tensor, + k: int, + n: int, + dtype: torch.dtype = torch.float16, + out: Optional[Tensor] = None, +) -> Tensor: + """Dequantize a k-bit blockwise quantized tensor. + + Args: + packed: int32 tensor of bit-plane packed values (from quantize_kbit). + absmax: Tensor of per-block absmax values (from quantize_kbit). + Supports float32 or uint8 (E4M4 format). + codebook: float32 codebook tensor with 2^k entries. + k: Bit width (2, 3, 4, or 5). + n: Number of original elements. + dtype: Output dtype. Defaults to float16. + out: Optional pre-allocated output tensor for CUDA graph compatibility. + Must have at least ceil(n/32)*32 elements and matching dtype. + + Returns: + Dequantized tensor of shape (n,) with the given dtype. 
+ """ + num_blocks = -(n // -32) + padded_n = num_blocks * 32 + + if out is not None: + if out.numel() < padded_n: + raise ValueError(f"out tensor has {out.numel()} elements, need at least {padded_n}") + if out.dtype != dtype: + raise ValueError(f"out dtype {out.dtype} does not match requested dtype {dtype}") + torch.ops.bitsandbytes.dequantize_kbit_(packed, codebook, absmax, k, n, dtype, out) + return out[:n] + + result = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, dtype) + return result[:n] + + @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def quantize( A: Tensor, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index da63bf6c6..55ea54995 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2601,3 +2601,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) + +// K-bit kernel definitions moved to ops.cu to avoid RDC device linking issues. diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index e7a1282bc..1bf2ec287 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -125,4 +125,7 @@ __global__ void kgemm_4bit_inference_naive( template __global__ void kfunc(T* A, T* B, T value, long n); +// K-bit kernel definitions live in ops.cu (not kernels.cu) to keep kernel +// and launch wrapper in the same compilation unit. No declarations needed here. 
+ #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 875c82b1c..a5cb96ed1 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -645,3 +645,225 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); template void percentileClipping(float* g, float* gnorm_vec, int step, const int n); template void percentileClipping(half* g, float* gnorm_vec, int step, const int n); + +// =========================================================================== +// K-bit blockwise quantization/dequantization (blocksize=32, K=2..5) +// +// Kernel definitions and launch wrappers in the same compilation unit +// to avoid RDC device linking issues with template instantiations. +// =========================================================================== + +// ---- Device helpers ---- + +__device__ __forceinline__ float warp_reduce_absmax_kbit(float val) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + return __shfl_sync(0xFFFFFFFF, val, 0); +} + +template __device__ __forceinline__ void pack_kbit_warp(unsigned char qval, unsigned int* packed_words) { +#pragma unroll + for (int bit = 0; bit < K; bit++) + packed_words[bit] = __ballot_sync(0xFFFFFFFF, (qval >> bit) & 1); +} + +template +__device__ __forceinline__ unsigned char unpack_kbit_warp(const unsigned int* packed_words, int lane_id) { + unsigned char val = 0; +#pragma unroll + for (int bit = 0; bit < K; bit++) + val |= ((packed_words[bit] >> lane_id) & 1) << bit; + return val; +} + +// ---- Stage 4: Full quantize kernel ---- + +template +__global__ void kQuantizeBlockwise_kbit( + const float* __restrict__ codebook, const T* __restrict__ A, float* __restrict__ absmax, + unsigned int* __restrict__ packed_out, const int n +) { + const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = threadIdx.x % 32; + const int block_start = warp_id * 32; + if (block_start >= n) + return; + float val = (block_start + lane_id < n) ? 
(float)A[block_start + lane_id] : 0.0f; + float amax = warp_reduce_absmax_kbit(fabsf(val)); + float amax_safe = fmaxf(amax, 1e-8f); + if (lane_id == 0) + absmax[warp_id] = amax; + float normalized = val / amax_safe; + float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f; + unsigned char best_idx = 0; + float best_dist = 1e10f; +#pragma unroll + for (int i = 0; i < (1 << K); i++) { + float cb_val = __shfl_sync(0xFFFFFFFF, cb, i); + float dist = fabsf(normalized - cb_val); + bool closer = (dist < best_dist); + best_dist = closer ? dist : best_dist; + best_idx = closer ? (unsigned char)i : best_idx; + } + unsigned int packed[K]; + pack_kbit_warp(best_idx, packed); + if (lane_id < K) + packed_out[warp_id * K + lane_id] = packed[lane_id]; +} + +// ---- E4M4 absmax decode ---- +// uint8 -> float: E4M4 format with configurable bias and IEEE-style subnormals. +// Normal (e > 0): 2^(e - BIAS) * (1 + m/16) +// Subnormal (e = 0): 2^(1 - BIAS) * (m/16) +// Zero (e = 0, m = 0): 0.0 +constexpr int E4M4_BIAS = 11; + +__device__ __forceinline__ float decode_e4m4_absmax(unsigned char raw) { + if (raw == 0) + return 0.0f; + int e = raw >> 4; + int m = raw & 0xF; + if (e == 0) { + // Subnormal (extremely rare in practice): 2^(1-BIAS) * m/16 + return ldexpf((float)m, 1 - E4M4_BIAS - 4); + } + // Normal: construct IEEE 754 float directly via bit manipulation. + // Target: 2^(e - BIAS) * (1 + m/16) + // IEEE 754: exponent_field = (e - BIAS) + 127, mantissa_field = m << 19 + unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23 | (unsigned int)m << 19; + return __uint_as_float(ieee); +} + +// Template helper: convert ABSMAX_T to float. +// Specialization for unsigned char uses E4M4 decode. 
+template __device__ __forceinline__ float load_absmax(const ABSMAX_T* absmax, int idx) { + return (float)absmax[idx]; +} + +template <> __device__ __forceinline__ float load_absmax(const unsigned char* absmax, int idx) { + return decode_e4m4_absmax(absmax[idx]); +} + +// ---- Stage 5: Full dequantize kernel ---- + +// Vectorized version: each warp processes BLOCKS_PER_WARP quant blocks, +// amortizing codebook load across multiple blocks. +// Templated on T (output type) and ABSMAX_T (absmax format). +template +__global__ void kDequantizeBlockwise_kbit_vec( + const unsigned int* __restrict__ packed_in, const float* __restrict__ codebook, const ABSMAX_T* __restrict__ absmax, + T* __restrict__ out, const int n +) { + const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = threadIdx.x % 32; + const int base_block = warp_id * BLOCKS_PER_WARP; + + if (base_block * 32 >= n) + return; + + // Load codebook into lane registers (one-time, amortized across BLOCKS_PER_WARP blocks) + float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f; + +#pragma unroll + for (int b = 0; b < BLOCKS_PER_WARP; b++) { + const int block_id = base_block + b; + const int block_start = block_id * 32; + if (block_start >= n) + break; + + float amax = load_absmax(absmax, block_id); + unsigned int packed[K]; +#pragma unroll + for (int bit = 0; bit < K; bit++) { + unsigned int word = (lane_id == bit) ? 
packed_in[block_id * K + bit] : 0; + packed[bit] = __shfl_sync(0xFFFFFFFF, word, bit); + } + unsigned char idx = unpack_kbit_warp(packed, lane_id); + float val = __shfl_sync(0xFFFFFFFF, cb, idx) * amax; + + if (block_start + lane_id < n) + out[block_start + lane_id] = (T)val; + } +} + +// ---- Launch wrappers ---- + +#define KBIT_WARPS_PER_BLOCK 8 +#define KBIT_THREADS_PER_BLOCK (KBIT_WARPS_PER_BLOCK * 32) // 256 + +// ---- Production kernel launchers (Stage 4-5) ---- + +template +void quantizeBlockwise_kbit(const float* codebook, const T* A, float* absmax, unsigned int* packed_out, int n) { + int num_blocks_quant = (n + 31) / 32; + int num_cuda_blocks = (num_blocks_quant + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK; + kQuantizeBlockwise_kbit<<>>(codebook, A, absmax, packed_out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +// Generic dequant launcher: supports all output types and absmax formats. +template +void dequantizeBlockwise_kbit( + const unsigned int* packed_in, const float* codebook, const ABSMAX_T* absmax, T* out, int n, cudaStream_t stream +) { + constexpr int BPW = 4; // blocks per warp + int num_blocks_quant = (n + 31) / 32; + int num_warps = (num_blocks_quant + BPW - 1) / BPW; + int num_cuda_blocks = (num_warps + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK; + kDequantizeBlockwise_kbit_vec + <<>>(packed_in, codebook, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +// ---- Template instantiations ---- + +#define INSTANTIATE_KBIT_QUANT(T, K) \ + template void quantizeBlockwise_kbit(const float*, const T*, float*, unsigned int*, int); + +INSTANTIATE_KBIT_QUANT(half, 2) +INSTANTIATE_KBIT_QUANT(half, 3) +INSTANTIATE_KBIT_QUANT(half, 4) +INSTANTIATE_KBIT_QUANT(half, 5) +INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 2) +INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 3) +INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 4) +INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 5) +INSTANTIATE_KBIT_QUANT(float, 2) +INSTANTIATE_KBIT_QUANT(float, 3) 
+INSTANTIATE_KBIT_QUANT(float, 4) +INSTANTIATE_KBIT_QUANT(float, 5) + +// Dequant instantiations: all output types × absmax types × K values +#define INSTANTIATE_KBIT_DEQUANT(T, K, ABSMAX_T) \ + template void dequantizeBlockwise_kbit( \ + const unsigned int*, const float*, const ABSMAX_T*, T*, int, cudaStream_t \ + ); + +// uint8 E4M4 absmax (default) +INSTANTIATE_KBIT_DEQUANT(half, 2, unsigned char) +INSTANTIATE_KBIT_DEQUANT(half, 3, unsigned char) +INSTANTIATE_KBIT_DEQUANT(half, 4, unsigned char) +INSTANTIATE_KBIT_DEQUANT(half, 5, unsigned char) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2, unsigned char) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3, unsigned char) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4, unsigned char) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5, unsigned char) +INSTANTIATE_KBIT_DEQUANT(float, 2, unsigned char) +INSTANTIATE_KBIT_DEQUANT(float, 3, unsigned char) +INSTANTIATE_KBIT_DEQUANT(float, 4, unsigned char) +INSTANTIATE_KBIT_DEQUANT(float, 5, unsigned char) + +// fp16 absmax (option) +INSTANTIATE_KBIT_DEQUANT(half, 2, half) +INSTANTIATE_KBIT_DEQUANT(half, 3, half) +INSTANTIATE_KBIT_DEQUANT(half, 4, half) +INSTANTIATE_KBIT_DEQUANT(half, 5, half) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2, half) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3, half) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4, half) +INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5, half) +INSTANTIATE_KBIT_DEQUANT(float, 2, half) +INSTANTIATE_KBIT_DEQUANT(float, 3, half) +INSTANTIATE_KBIT_DEQUANT(float, 4, half) +INSTANTIATE_KBIT_DEQUANT(float, 5, half) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 340f06145..615523224 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -382,7 +382,77 @@ void gemv_4bit_inference_fp32( gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } -#endif +#endif // BUILD_XPU + +// =========================================================================== +// K-bit blockwise 
quantization/dequantization wrappers (unmangled) +// =========================================================================== +#if BUILD_CUDA || BUILD_HIP + +// Forward declarations of ops.cu template functions +template void quantizeBlockwise_kbit(const float*, const T*, float*, unsigned int*, int); +template +void dequantizeBlockwise_kbit(const unsigned int*, const float*, const ABSMAX_T*, T*, int, cudaStream_t); + +// Unmangled quantize wrappers +#define MAKE_KBIT_QUANT(tname, T, K) \ + void quantize_kbit_##tname##_k##K( \ + const float* codebook, const T* A, float* absmax, unsigned int* packed_out, int n \ + ) { \ + quantizeBlockwise_kbit(codebook, A, absmax, packed_out, n); \ + } + +MAKE_KBIT_QUANT(fp16, half, 2) +MAKE_KBIT_QUANT(fp16, half, 3) +MAKE_KBIT_QUANT(fp16, half, 4) +MAKE_KBIT_QUANT(fp16, half, 5) +MAKE_KBIT_QUANT(bf16, __nv_bfloat16, 2) +MAKE_KBIT_QUANT(bf16, __nv_bfloat16, 3) +MAKE_KBIT_QUANT(bf16, __nv_bfloat16, 4) +MAKE_KBIT_QUANT(bf16, __nv_bfloat16, 5) +MAKE_KBIT_QUANT(fp32, float, 2) +MAKE_KBIT_QUANT(fp32, float, 3) +MAKE_KBIT_QUANT(fp32, float, 4) +MAKE_KBIT_QUANT(fp32, float, 5) + +// Unmangled dequant wrappers: output type × absmax type × K +#define MAKE_KBIT_DEQUANT(tname, T, aname, ABSMAX_T, K) \ + void dequantize_kbit_##tname##_##aname##_k##K( \ + const unsigned int* packed_in, const float* codebook, const ABSMAX_T* absmax, T* out, int n, \ + cudaStream_t stream \ + ) { \ + dequantizeBlockwise_kbit(packed_in, codebook, absmax, out, n, stream); \ + } + +// uint8 E4M4 absmax (default) - all output types +MAKE_KBIT_DEQUANT(fp16, half, u8abs, unsigned char, 2) +MAKE_KBIT_DEQUANT(fp16, half, u8abs, unsigned char, 3) +MAKE_KBIT_DEQUANT(fp16, half, u8abs, unsigned char, 4) +MAKE_KBIT_DEQUANT(fp16, half, u8abs, unsigned char, 5) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 2) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 3) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 4) 
+MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 5) +MAKE_KBIT_DEQUANT(fp32, float, u8abs, unsigned char, 2) +MAKE_KBIT_DEQUANT(fp32, float, u8abs, unsigned char, 3) +MAKE_KBIT_DEQUANT(fp32, float, u8abs, unsigned char, 4) +MAKE_KBIT_DEQUANT(fp32, float, u8abs, unsigned char, 5) + +// fp16 absmax (option) - all output types +MAKE_KBIT_DEQUANT(fp16, half, fp16abs, half, 2) +MAKE_KBIT_DEQUANT(fp16, half, fp16abs, half, 3) +MAKE_KBIT_DEQUANT(fp16, half, fp16abs, half, 4) +MAKE_KBIT_DEQUANT(fp16, half, fp16abs, half, 5) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 2) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 3) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 4) +MAKE_KBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 5) +MAKE_KBIT_DEQUANT(fp32, float, fp16abs, half, 2) +MAKE_KBIT_DEQUANT(fp32, float, fp16abs, half, 3) +MAKE_KBIT_DEQUANT(fp32, float, fp16abs, half, 4) +MAKE_KBIT_DEQUANT(fp32, float, fp16abs, half, 5) + +#endif // BUILD_CUDA || BUILD_HIP (kbit unmangled) extern "C" { #if BUILD_CUDA || BUILD_HIP @@ -887,5 +957,70 @@ bool has_avx512f_cpu() { return has_avx512f(); } #if defined(__AVX512BF16__) bool has_avx512bf16_cpu() { return has_avx512bf16(); } #endif +#endif + +// =========================================================================== +// K-bit blockwise quantization/dequantization (extern "C" exports) +// =========================================================================== +#if BUILD_CUDA || BUILD_HIP + +// Production kernels (Stage 4-5) - quantize only +#define MAKE_CKBIT(tname, T, K) \ + void cquantize_kbit_##tname##_k##K( \ + const float* codebook, const T* A, float* absmax, unsigned int* packed_out, int n \ + ) { \ + quantize_kbit_##tname##_k##K(codebook, A, absmax, packed_out, n); \ + } + +MAKE_CKBIT(fp16, half, 2) +MAKE_CKBIT(fp16, half, 3) +MAKE_CKBIT(fp16, half, 4) +MAKE_CKBIT(fp16, half, 5) +MAKE_CKBIT(bf16, __nv_bfloat16, 2) +MAKE_CKBIT(bf16, __nv_bfloat16, 3) +MAKE_CKBIT(bf16, 
__nv_bfloat16, 4) +MAKE_CKBIT(bf16, __nv_bfloat16, 5) +MAKE_CKBIT(fp32, float, 2) +MAKE_CKBIT(fp32, float, 3) +MAKE_CKBIT(fp32, float, 4) +MAKE_CKBIT(fp32, float, 5) + +// Dequant extern C wrappers: output type × absmax type × K +#define MAKE_CKBIT_DEQUANT(tname, T, aname, ABSMAX_T, K) \ + void cdequantize_kbit_##tname##_##aname##_k##K( \ + const unsigned int* packed_in, const float* codebook, const ABSMAX_T* absmax, T* out, int n, \ + cudaStream_t stream \ + ) { \ + dequantize_kbit_##tname##_##aname##_k##K(packed_in, codebook, absmax, out, n, stream); \ + } + +// uint8 E4M4 absmax - all output types +MAKE_CKBIT_DEQUANT(fp16, half, u8abs, unsigned char, 2) +MAKE_CKBIT_DEQUANT(fp16, half, u8abs, unsigned char, 3) +MAKE_CKBIT_DEQUANT(fp16, half, u8abs, unsigned char, 4) +MAKE_CKBIT_DEQUANT(fp16, half, u8abs, unsigned char, 5) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 2) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 3) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 4) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, u8abs, unsigned char, 5) +MAKE_CKBIT_DEQUANT(fp32, float, u8abs, unsigned char, 2) +MAKE_CKBIT_DEQUANT(fp32, float, u8abs, unsigned char, 3) +MAKE_CKBIT_DEQUANT(fp32, float, u8abs, unsigned char, 4) +MAKE_CKBIT_DEQUANT(fp32, float, u8abs, unsigned char, 5) + +// fp16 absmax - all output types +MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 2) +MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 3) +MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 4) +MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 5) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 2) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 3) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 4) +MAKE_CKBIT_DEQUANT(bf16, __nv_bfloat16, fp16abs, half, 5) +MAKE_CKBIT_DEQUANT(fp32, float, fp16abs, half, 2) +MAKE_CKBIT_DEQUANT(fp32, float, fp16abs, half, 3) +MAKE_CKBIT_DEQUANT(fp32, float, fp16abs, half, 4) +MAKE_CKBIT_DEQUANT(fp32, float, fp16abs, 
half, 5) + #endif } diff --git a/tests/test_kbit_quantization.py b/tests/test_kbit_quantization.py new file mode 100644 index 000000000..5b145cc4d --- /dev/null +++ b/tests/test_kbit_quantization.py @@ -0,0 +1,1466 @@ +""" +Tests for k-bit quantization (K=2..5, blocksize=32). + +Staged implementation following cuda-spec-additions.md: + Stage 0: Pure Python reference + Stage 1-3: Temporary CUDA test kernels (pack/unpack, memory format, codebook lookup) + Stage 4: Full quantize kernel + Stage 5: Full dequantize kernel + Stage 6: Round-trip error analysis + Stage 7: Cross-validation against existing NF4 + Stage 8: Performance benchmarking +""" + +import ctypes as ct +import math + +import pytest +from scipy.stats import norm +import torch + +# --------------------------------------------------------------------------- +# Codebook generation +# --------------------------------------------------------------------------- + + +def create_normal_float_codebook(k: int) -> torch.Tensor: + """Create a 2^k-entry normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]). + + For k bits we have 2^k reconstruction levels placed at the expected values + of N(0,1) within 2^k equiprobable bins. The result is sorted ascending + and normalized so the largest magnitude is 1.0. + + For k=4 this is conceptually the same as the NF4 datatype (with minor + numerical differences due to the asymmetric extra-value trick in the + existing bitsandbytes NF4). 
+ """ + n_levels = 1 << k + # Midpoints of n_levels equiprobable bins + quantiles = torch.linspace(0.5 / n_levels, 1.0 - 0.5 / n_levels, n_levels) + values = torch.tensor(norm.ppf(quantiles.numpy()), dtype=torch.float32) + # Normalize to [-1, 1] + values = values / values.abs().max() + return values + + +# --------------------------------------------------------------------------- +# Stage 0: Pure Python reference implementation +# --------------------------------------------------------------------------- + +BLOCKSIZE = 32 + + +def quantize_kbit_ref( + A: torch.Tensor, + codebook: torch.Tensor, + blocksize: int = BLOCKSIZE, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure-PyTorch k-bit blockwise quantization (reference, not optimized). + + Args: + A: Input tensor (any shape, will be flattened). + codebook: 1-D float tensor of 2^k reconstruction levels, sorted ascending. + blocksize: Number of elements per quantization block (must be 32). + + Returns: + indices: uint8 tensor of shape (n,) with values in [0, 2^k). + absmax: float32 tensor of shape (num_blocks,). 
+ """ + assert blocksize == 32, "k-bit reference only supports blocksize=32" + A_flat = A.float().reshape(-1) + n = A_flat.numel() + # Pad to multiple of blocksize + pad = (blocksize - n % blocksize) % blocksize + if pad > 0: + A_flat = torch.nn.functional.pad(A_flat, (0, pad)) + n_padded = A_flat.numel() + num_blocks = n_padded // blocksize + + blocks = A_flat.reshape(num_blocks, blocksize) + absmax = blocks.abs().max(dim=1).values # (num_blocks,) + # Avoid division by zero + absmax_safe = absmax.clamp(min=1e-8) + # Normalize to [-1, 1] + normalized = blocks / absmax_safe.unsqueeze(1) + + # Find nearest codebook entry for each element (brute force) + # codebook: (2^k,), normalized: (num_blocks, blocksize) + cb = codebook.float().unsqueeze(0).unsqueeze(0) # (1, 1, 2^k) + norm_exp = normalized.unsqueeze(2) # (num_blocks, blocksize, 1) + distances = (norm_exp - cb).abs() # (num_blocks, blocksize, 2^k) + indices = distances.argmin(dim=2).to(torch.uint8) # (num_blocks, blocksize) + + # Flatten and trim padding + indices = indices.reshape(-1)[:n] + return indices, absmax + + +def dequantize_kbit_ref( + indices: torch.Tensor, + absmax: torch.Tensor, + codebook: torch.Tensor, + dtype: torch.dtype = torch.float32, + blocksize: int = BLOCKSIZE, +) -> torch.Tensor: + """Pure-PyTorch k-bit blockwise dequantization (reference). + + Args: + indices: uint8 tensor of shape (n,) with values in [0, 2^k). + absmax: float32 tensor of shape (num_blocks,). + codebook: 1-D float tensor of 2^k reconstruction levels. + dtype: Output dtype. + blocksize: Must be 32. + + Returns: + Dequantized tensor of shape (n,) with the given dtype. 
+ """ + assert blocksize == 32, "k-bit reference only supports blocksize=32" + n = indices.numel() + # Pad indices to multiple of blocksize + pad = (blocksize - n % blocksize) % blocksize + if pad > 0: + indices = torch.nn.functional.pad(indices.long(), (0, pad)) + n_padded = indices.numel() + num_blocks = n_padded // blocksize + + # Lookup codebook values + cb_values = codebook.float()[indices.long()] # (n_padded,) + cb_values = cb_values.reshape(num_blocks, blocksize) + + # Scale by absmax + out = cb_values * absmax.unsqueeze(1) + + # Flatten and trim + out = out.reshape(-1)[:n] + return out.to(dtype) + + +# --------------------------------------------------------------------------- +# Bit-plane packing/unpacking (Python reference for testing CUDA) +# --------------------------------------------------------------------------- + + +def pack_kbit_ref(indices: torch.Tensor, k: int, blocksize: int = BLOCKSIZE) -> torch.Tensor: + """Pack k-bit indices into bit-plane uint32 words (Python reference). + + For each block of 32 elements, produces k uint32 words where word j + contains bit j of all 32 elements (bit-plane layout). + + Args: + indices: uint8 tensor of shape (n,). + k: Bit width. + + Returns: + packed: uint32 tensor of shape (num_blocks * k,). 
+ """ + n = indices.numel() + pad = (blocksize - n % blocksize) % blocksize + if pad > 0: + indices = torch.nn.functional.pad(indices.int(), (0, pad)) + n_padded = indices.numel() + num_blocks = n_padded // blocksize + blocks = indices.int().reshape(num_blocks, blocksize) + + packed_words = [] + for b in range(num_blocks): + for bit in range(k): + word = 0 + for i in range(blocksize): + word |= ((int(blocks[b, i]) >> bit) & 1) << i + # Convert to signed int32 (reinterpret high bit as sign) + if word >= (1 << 31): + word -= 1 << 32 + packed_words.append(word) + return torch.tensor(packed_words, dtype=torch.int32) + + +def unpack_kbit_ref(packed: torch.Tensor, k: int, n: int, blocksize: int = BLOCKSIZE) -> torch.Tensor: + """Unpack bit-plane uint32 words back to k-bit indices (Python reference). + + Args: + packed: int32 tensor of shape (num_blocks * k,). + k: Bit width. + n: Number of original elements. + + Returns: + indices: uint8 tensor of shape (n,). + """ + num_blocks = packed.numel() // k + indices = [] + for b in range(num_blocks): + words_raw = packed[b * k : b * k + k].tolist() + # Convert signed int32 back to unsigned + words = [(w & 0xFFFFFFFF) for w in words_raw] + for i in range(blocksize): + val = 0 + for bit in range(k): + val |= ((words[bit] >> i) & 1) << bit + indices.append(val) + return torch.tensor(indices[:n], dtype=torch.uint8) + + +# =========================================================================== +# Tests +# =========================================================================== + + +class TestCodebook: + """Test codebook generation.""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_codebook_size(self, k): + cb = create_normal_float_codebook(k) + assert cb.numel() == (1 << k) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_codebook_sorted(self, k): + cb = create_normal_float_codebook(k) + assert (cb[1:] >= cb[:-1]).all() + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_codebook_range(self, k): + 
cb = create_normal_float_codebook(k) + assert cb.abs().max().item() == pytest.approx(1.0, abs=1e-6) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_codebook_symmetric_ish(self, k): + """Codebook should be roughly symmetric around 0.""" + cb = create_normal_float_codebook(k) + assert abs(cb.mean().item()) < 0.1 # not exactly 0 for odd counts + + +class TestQuantizeRef: + """Stage 0: Test the pure Python reference implementation.""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_round_trip_basic(self, k): + """Quantize then dequantize; output should be close to input.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k) + A = torch.randn(1024) + indices, absmax = quantize_kbit_ref(A, cb) + recovered = dequantize_kbit_ref(indices, absmax, cb) + # Check shapes + assert indices.shape == (1024,) + assert absmax.shape == (1024 // 32,) + assert recovered.shape == (1024,) + # MSE should decrease with more bits + mse = ((A - recovered) ** 2).mean().item() + assert mse < 1.0 # very loose sanity check + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_mse_decreases_with_bits(self, k): + """More bits should give lower MSE.""" + torch.manual_seed(42) + A = torch.randn(4096) + mses = {} + for ki in [2, 3, 4, 5]: + cb = create_normal_float_codebook(ki) + indices, absmax = quantize_kbit_ref(A, cb) + recovered = dequantize_kbit_ref(indices, absmax, cb) + mses[ki] = ((A - recovered) ** 2).mean().item() + # MSE should be monotonically decreasing (or very close) + for ki in [3, 4, 5]: + assert mses[ki] <= mses[ki - 1] * 1.05 # 5% tolerance for noise + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_indices_in_range(self, k): + cb = create_normal_float_codebook(k) + A = torch.randn(256) + indices, _ = quantize_kbit_ref(A, cb) + assert indices.max().item() < (1 << k) + assert indices.min().item() >= 0 + + @pytest.mark.parametrize("n", [1, 31, 32, 33, 63, 64, 65, 1000]) + def test_various_sizes(self, n): + """Non-aligned sizes should 
work.""" + k = 3 + cb = create_normal_float_codebook(k) + A = torch.randn(n) + indices, absmax = quantize_kbit_ref(A, cb) + assert indices.shape == (n,) + num_blocks = math.ceil(n / 32) + assert absmax.shape == (num_blocks,) + recovered = dequantize_kbit_ref(indices, absmax, cb) + assert recovered.shape == (n,) + + def test_all_zeros(self): + """All-zero input: absmax should be clamped, indices should point to ~0.""" + k = 3 + cb = create_normal_float_codebook(k) + A = torch.zeros(64) + indices, absmax = quantize_kbit_ref(A, cb) + recovered = dequantize_kbit_ref(indices, absmax, cb) + assert recovered.abs().max().item() < 1e-4 + + def test_absmax_correctness(self): + """Absmax should match manual per-block computation.""" + k = 3 + cb = create_normal_float_codebook(k) + A = torch.randn(128) + _, absmax = quantize_kbit_ref(A, cb) + expected = A.reshape(-1, 32).abs().max(dim=1).values + assert torch.allclose(absmax, expected) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_analytical_error_bound(self, k): + """Max per-element error should be bounded by max_gap/2 * absmax.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k) + A = torch.randn(4096) + indices, absmax = quantize_kbit_ref(A, cb) + recovered = dequantize_kbit_ref(indices, absmax, cb) + errors = (A - recovered).abs() + + # Max gap in codebook + gaps = cb[1:] - cb[:-1] + max_gap = gaps.max().item() + + # Per block, error <= max_gap/2 * absmax_of_block + A_blocks = A.reshape(-1, 32) + err_blocks = errors.reshape(-1, 32) + for i in range(A_blocks.shape[0]): + block_bound = max_gap / 2 * absmax[i].item() + block_max_err = err_blocks[i].max().item() + assert block_max_err <= block_bound + 1e-6, f"Block {i}: max_err={block_max_err}, bound={block_bound}" + + +class TestPackUnpackRef: + """Test the Python reference bit-plane packing.""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_round_trip(self, k): + n = 128 + indices = torch.randint(0, 1 << k, (n,), dtype=torch.uint8) + 
        packed = pack_kbit_ref(indices, k)
        recovered = unpack_kbit_ref(packed, k, n)
        assert (indices == recovered).all()

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    def test_packed_size(self, k):
        # k words of 32 bits per block of 32 elements.
        n = 128
        indices = torch.randint(0, 1 << k, (n,), dtype=torch.uint8)
        packed = pack_kbit_ref(indices, k)
        num_blocks = math.ceil(n / 32)
        assert packed.numel() == num_blocks * k

    @pytest.mark.parametrize("n", [1, 31, 32, 33, 64, 65])
    def test_non_aligned_sizes(self, n):
        k = 3
        indices = torch.randint(0, 1 << k, (n,), dtype=torch.uint8)
        packed = pack_kbit_ref(indices, k)
        recovered = unpack_kbit_ref(packed, k, n)
        assert (indices == recovered).all()

    def test_known_pattern_k3(self):
        """Verify a known bit pattern for K=3."""
        # 32 elements: indices 0,1,2,3,4,5,6,7 repeated 4 times
        indices = torch.tensor(list(range(8)) * 4, dtype=torch.uint8)
        assert indices.numel() == 32
        packed = pack_kbit_ref(indices, k=3)
        assert packed.numel() == 3  # 1 block * 3 words

        # Bit 0 of each element: 0,1,0,1,0,1,0,1, repeated
        # bit0: [0,1,0,1,0,1,0,1, 0,1,0,1,0,1,0,1, 0,1,0,1,0,1,0,1, 0,1,0,1,0,1,0,1]
        expected_w0 = 0
        for i in range(32):
            expected_w0 |= ((indices[i].item() >> 0) & 1) << i
        # Mask to 32 bits so the int32 sign reinterpretation does not matter.
        assert (packed[0].item() & 0xFFFFFFFF) == (expected_w0 & 0xFFFFFFFF)

        # Verify round-trip
        recovered = unpack_kbit_ref(packed, k=3, n=32)
        assert (indices == recovered).all()


# ===========================================================================
# CUDA helpers -- ctypes wrappers for the C interface
# ===========================================================================


def _get_lib():
    """Load the bitsandbytes native library."""
    # Imported lazily so merely collecting the tests does not require the
    # native extension to be present.
    from bitsandbytes.cextension import lib

    return lib


def _get_ptr(t):
    """Get a ctypes-compatible pointer from a CUDA tensor."""
    return ct.c_void_p(t.data_ptr())


def _dtype_to_tname(dtype):
    """Map torch dtype to C type name suffix."""
    return {torch.float16: "fp16", torch.bfloat16: "bf16", torch.float32: "fp32"}[dtype]


def _cuda_quantize_kbit(A, codebook, k):
    """Call cquantize_kbit_{tname}_k{k}. Returns (packed, absmax)."""
    lib = _get_lib()
    n = A.numel()
    num_blocks = (n + 31) // 32
    tname = _dtype_to_tname(A.dtype)
    # NOTE(review): both buffers are over-allocated (+k words / +1 block) —
    # presumably so the kernel may safely read/write one block past the
    # logical end; confirm against the kernel's tail handling.
    packed = torch.zeros(num_blocks * k + k, dtype=torch.int32, device=A.device)
    absmax = torch.zeros(num_blocks + 1, dtype=torch.float32, device=A.device)  # +1 for padding
    fn = getattr(lib, f"cquantize_kbit_{tname}_k{k}")
    fn(_get_ptr(codebook), _get_ptr(A), _get_ptr(absmax), _get_ptr(packed), ct.c_int(n))
    torch.cuda.synchronize()
    # Trim the padding before returning so callers see exact logical sizes.
    return packed[: num_blocks * k], absmax[:num_blocks]


def _cuda_dequantize_kbit(packed, codebook, absmax, k, n, dtype=torch.float16):
    """Call cdequantize_kbit_{tname}_{aname}_k{k} with native output type.

    If absmax is float32, encode to E4M4 first.
    """
    from bitsandbytes.functional import encode_absmax_e4m4

    lib = _get_lib()
    num_blocks = (n + 31) // 32
    # Pad packed buffer
    packed_padded = torch.zeros(num_blocks * k + k, dtype=torch.int32, device=packed.device)
    packed_padded[: packed.numel()] = packed
    # Handle absmax encoding
    if absmax.dtype == torch.float32:
        absmax_enc = encode_absmax_e4m4(absmax)
    else:
        absmax_enc = absmax
    # uint8 absmax -> E4M4-encoded entry point, fp16 absmax -> fp16 entry point.
    aname = {torch.uint8: "u8abs", torch.float16: "fp16abs"}[absmax_enc.dtype]
    absmax_padded = torch.zeros(num_blocks + 1, dtype=absmax_enc.dtype, device=packed.device)
    absmax_padded[: absmax_enc.numel()] = absmax_enc
    # Native output type
    tname = _dtype_to_tname(dtype)
    out = torch.zeros(num_blocks * 32, dtype=dtype, device=packed.device)
    fn = getattr(lib, f"cdequantize_kbit_{tname}_{aname}_k{k}")
    fn(
        _get_ptr(packed_padded),
        _get_ptr(codebook),
        _get_ptr(absmax_padded),
        _get_ptr(out),
        ct.c_int(n),
        ct.c_void_p(0),
    )
    torch.cuda.synchronize()
    return out[:n]


def _cuda_dequantize_kbit_prepped(packed_padded, codebook, absmax_u8_padded, k, n, out):
    """Direct kernel call for benchmarks -- no encoding, no
    allocation.

    Caller must provide pre-padded packed/absmax and pre-allocated output.
    """
    lib = _get_lib()
    tname = _dtype_to_tname(out.dtype)
    fn = getattr(lib, f"cdequantize_kbit_{tname}_u8abs_k{k}")
    # No torch.cuda.synchronize() here: benchmark callers time the launches
    # themselves with CUDA events.
    fn(
        _get_ptr(packed_padded),
        _get_ptr(codebook),
        _get_ptr(absmax_u8_padded),
        _get_ptr(out),
        ct.c_int(n),
        ct.c_void_p(0),
    )


# ===========================================================================
# CUDA Tests
# ===========================================================================

requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")


@requires_cuda
class TestStage4QuantizeCUDA:
    """Stage 4: Full quantize kernel."""

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    def test_absmax_correctness(self, k):
        """CUDA absmax should match manual per-block computation."""
        torch.manual_seed(42)
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(1024, dtype=torch.float16, device="cuda")
        _, absmax = _cuda_quantize_kbit(A, cb, k)
        expected = A.float().reshape(-1, 32).abs().max(dim=1).values
        assert torch.allclose(absmax, expected, atol=1e-4), f"max diff: {(absmax - expected).abs().max()}"

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_all_dtypes(self, k, dtype):
        # Shape/size smoke test across all supported input dtypes.
        torch.manual_seed(42)
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(128, dtype=dtype, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        assert packed.numel() == (128 // 32) * k
        assert absmax.numel() == 128 // 32

    @pytest.mark.parametrize("n", [32, 64, 33, 1, 1000])
    def test_various_sizes(self, n):
        # Includes non-block-aligned sizes (1, 33, 1000).
        k = 3
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(n, dtype=torch.float16, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        num_blocks = (n + 31) // 32
        assert packed.numel() == num_blocks * k
        assert absmax.numel() == num_blocks
@requires_cuda
class TestStage5DequantizeCUDA:
    """Stage 5: Full dequantize kernel."""

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    def test_matches_ref(self, k):
        """CUDA dequant output should match Python reference."""
        torch.manual_seed(42)
        cb = create_normal_float_codebook(k)
        A = torch.randn(1024, dtype=torch.float16)
        # Python ref
        ref_indices, ref_absmax = quantize_kbit_ref(A.float(), cb)
        ref_recovered = dequantize_kbit_ref(ref_indices, ref_absmax, cb)
        # CUDA quantize -> dequantize round trip
        packed, absmax = _cuda_quantize_kbit(A.cuda(), cb.cuda(), k)
        recovered = _cuda_dequantize_kbit(packed, cb.cuda(), absmax, k, A.numel(), dtype=torch.float16)
        # E4M4 scale quantization + fp16 intermediate adds error on top of fp16 rounding
        assert torch.allclose(recovered.cpu().float(), ref_recovered.float(), atol=0.1), (
            f"max diff: {(recovered.cpu().float() - ref_recovered.float()).abs().max()}"
        )

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_all_dtypes(self, k, dtype):
        # Shape/dtype smoke test for the native output-type entry points.
        torch.manual_seed(42)
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(256, dtype=dtype, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=dtype)
        assert recovered.shape == A.shape
        assert recovered.dtype == dtype

    @pytest.mark.parametrize("n", [1, 31, 32, 33, 64, 65, 1000])
    def test_various_sizes(self, n):
        # Non-block-aligned sizes must round-trip with correct output shape.
        k = 3
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(n, dtype=torch.float16, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, n, dtype=torch.float16)
        assert recovered.shape == (n,)

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    def test_error_bound(self, k):
        """Round-trip error should be within analytical bounds (loosened for E4M4 + fp16)."""
        torch.manual_seed(42)
        cb = create_normal_float_codebook(k).cuda()
        A = torch.randn(4096, dtype=torch.float32, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=torch.float32)
        errors = (A - recovered).abs()
        max_gap = (cb[1:] - cb[:-1]).max().item()
        # Per block, max error has two sources:
        # 1. Quantization error: max_gap/2 * absmax (codebook nearest-neighbor)
        # 2. E4M4 scale error: absmax is quantized with up to 1/16 relative error
        # Total bound: (max_gap/2 + 1/16) * absmax + epsilon
        for i in range(absmax.numel()):
            block_bound = (max_gap / 2 + 1 / 16) * absmax[i].item() + 1e-6
            block_err = errors[i * 32 : min((i + 1) * 32, A.numel())].max().item()
            assert block_err <= block_bound, f"Block {i}: max_err={block_err}, bound={block_bound}"


# ===========================================================================
# Stage 6: Round-Trip Error Analysis
# ===========================================================================


@requires_cuda
class TestStage6ErrorAnalysis:
    """Stage 6: Empirical error analysis on large tensors."""

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    def test_analytical_bound_large(self, k):
        """Max per-block error must stay within analytical bound on 1M+ elements."""
        torch.manual_seed(123)
        cb = create_normal_float_codebook(k).cuda()
        n = 1_048_576  # 1M elements
        A = torch.randn(n, dtype=torch.float32, device="cuda")
        packed, absmax = _cuda_quantize_kbit(A, cb, k)
        recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, n, dtype=torch.float32)
        errors = (A - recovered).abs()
        max_gap = (cb[1:] - cb[:-1]).max().item()
        # Per block, error has two sources:
        # 1. Quantization error: max_gap/2 * absmax (codebook nearest-neighbor)
        # 2.
E4M4 scale error: absmax quantized with up to 1/16 relative error + # Total bound: (max_gap/2 + 1/16) * absmax + epsilon + num_blocks = (n + 31) // 32 + err_blocks = errors.reshape(num_blocks, 32) + block_max_errs = err_blocks.max(dim=1).values + block_bounds = (max_gap / 2 + 1 / 16) * absmax + 1e-6 + violations = (block_max_errs > block_bounds).sum().item() + assert violations == 0, f"{violations}/{num_blocks} blocks violated analytical bound" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_mse_decreases_with_bits(self, k): + """More bits should yield lower MSE (CUDA round-trip).""" + torch.manual_seed(42) + n = 1_048_576 + A = torch.randn(n, dtype=torch.float32, device="cuda") + mses = {} + for ki in [2, 3, 4, 5]: + cb = create_normal_float_codebook(ki).cuda() + packed, absmax = _cuda_quantize_kbit(A, cb, ki) + recovered = _cuda_dequantize_kbit(packed, cb, absmax, ki, n, dtype=torch.float32) + mses[ki] = ((A - recovered) ** 2).mean().item() + for ki in [3, 4, 5]: + assert mses[ki] <= mses[ki - 1] * 1.05, ( + f"MSE did not decrease from K={ki - 1} ({mses[ki - 1]:.6f}) to K={ki} ({mses[ki]:.6f})" + ) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_empirical_mse_and_max_error(self, k): + """Report empirical MSE and max absolute error (1M elements, normal data).""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k).cuda() + n = 1_048_576 + A = torch.randn(n, dtype=torch.float32, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, n, dtype=torch.float32) + errors = (A - recovered).abs() + mse = ((A - recovered) ** 2).mean().item() + max_err = errors.max().item() + # SQNR = signal power / noise power (in dB) + signal_power = (A**2).mean().item() + sqnr_db = 10 * math.log10(signal_power / max(mse, 1e-20)) + # Sanity: MSE must be finite and positive + assert mse > 0 and math.isfinite(mse), f"Bad MSE: {mse}" + assert max_err > 0 and math.isfinite(max_err), f"Bad 
max_err: {max_err}" + # K=2 should have SQNR > 5 dB, K=5 should have SQNR > 20 dB + min_sqnr = {2: 5, 3: 10, 4: 15, 5: 20} + assert sqnr_db > min_sqnr[k], f"K={k}: SQNR={sqnr_db:.1f} dB too low (expected >{min_sqnr[k]} dB)" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_dtype_error_consistency(self, k, dtype): + """Error should not blow up for fp16/bf16 vs fp32.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k).cuda() + n = 32768 + A = torch.randn(n, dtype=dtype, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, n, dtype=dtype) + mse = ((A.float() - recovered.float()) ** 2).mean().item() + # Just verify MSE is finite and reasonable + assert mse > 0 and math.isfinite(mse) and mse < 10.0, f"Bad MSE for {dtype}: {mse}" + + +# =========================================================================== +# Stage 7: Cross-Validation Against Existing NF4 +# =========================================================================== + + +@requires_cuda +class TestStage7NF4CrossValidation: + """Stage 7: Compare K=4 kbit kernel against existing NF4 dequantize.""" + + def _get_nf4_codebook_sorted(self): + """Return the existing bitsandbytes NF4 codebook, sorted ascending.""" + from bitsandbytes.functional import get_4bit_type + + nf4 = get_4bit_type("nf4", device="cuda") + # The existing NF4 data is already sorted for the 16-entry list + return nf4 + + def test_mse_quality_comparison(self): + """New K=4 kernel MSE should be within 10% of existing NF4 MSE.""" + from bitsandbytes.functional import dequantize_nf4, quantize_nf4 + + torch.manual_seed(42) + n = 131072 # 128K elements + A = torch.randn(n, dtype=torch.float16, device="cuda") + + # Existing NF4 path (blocksize=64 is default) + nf4_packed, nf4_state = quantize_nf4(A, blocksize=64) + nf4_recovered = dequantize_nf4(nf4_packed, nf4_state) + 
nf4_mse = ((A.float() - nf4_recovered.float()) ** 2).mean().item() + + # New kbit K=4 path (blocksize=32) + cb = create_normal_float_codebook(4).cuda() + packed, absmax = _cuda_quantize_kbit(A, cb, 4) + kbit_recovered = _cuda_dequantize_kbit(packed, cb, absmax, 4, n, dtype=torch.float16) + kbit_mse = ((A.float() - kbit_recovered.float()) ** 2).mean().item() + + # Allow kbit MSE to be up to 2x of NF4 (different blocksize: 32 vs 64) + # Smaller blocksize means more overhead but potentially different quality + assert kbit_mse < nf4_mse * 2.0, f"K=4 kbit MSE ({kbit_mse:.6f}) is more than 2x NF4 MSE ({nf4_mse:.6f})" + + def test_codebook_similarity(self): + """Our K=4 NF codebook should be similar to the existing NF4 codebook.""" + nf4_cb = self._get_nf4_codebook_sorted() + our_cb = create_normal_float_codebook(4).cuda() + # Both have 16 entries, both approximate N(0,1) quantiles + # They won't be identical (existing NF4 has an asymmetric zero trick) + # but should be close + max_diff = (nf4_cb - our_cb).abs().max().item() + assert max_diff < 0.15, f"Codebooks differ too much: max_diff={max_diff}" + + def test_same_codebook_similar_output(self): + """When using the exact same NF4 codebook, outputs should be very close.""" + nf4_cb = self._get_nf4_codebook_sorted() + torch.manual_seed(42) + n = 32768 + A = torch.randn(n, dtype=torch.float32, device="cuda") + + # Python reference with NF4 codebook + ref_indices, ref_absmax = quantize_kbit_ref(A.cpu(), nf4_cb.cpu()) + ref_recovered = dequantize_kbit_ref(ref_indices, ref_absmax, nf4_cb.cpu()) + + # CUDA kbit with same NF4 codebook (goes through E4M4 + fp16 output, then casts) + packed, absmax = _cuda_quantize_kbit(A, nf4_cb, 4) + cuda_recovered = _cuda_dequantize_kbit(packed, nf4_cb, absmax, 4, n, dtype=torch.float32) + + # Loosened tolerance to account for E4M4 scale quantization + fp16 intermediate + assert torch.allclose(cuda_recovered.cpu(), ref_recovered, atol=0.1), ( + f"max diff: {(cuda_recovered.cpu() - 
ref_recovered).abs().max()}" + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_all_dtypes_nf4_codebook(self, dtype): + """K=4 with NF4 codebook should work for all dtypes.""" + nf4_cb = self._get_nf4_codebook_sorted() + torch.manual_seed(42) + n = 1024 + A = torch.randn(n, dtype=dtype, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, nf4_cb, 4) + recovered = _cuda_dequantize_kbit(packed, nf4_cb, absmax, 4, n, dtype=dtype) + mse = ((A.float() - recovered.float()) ** 2).mean().item() + assert mse > 0 and math.isfinite(mse), f"Bad MSE: {mse}" + + +# =========================================================================== +# Stage 8: Performance Benchmarking +# =========================================================================== + + +@requires_cuda +class TestStage8PerformanceBenchmark: + """Stage 8: Measure dequant throughput and HBM bandwidth utilization.""" + + @staticmethod + def _get_hbm_bandwidth_gbs(): + """Estimate theoretical peak HBM bandwidth in GB/s for the current GPU.""" + name = torch.cuda.get_device_name().lower() + # Known bandwidth values (approximate) + if "a100" in name: + return 2000.0 + elif "h100" in name: + return 3350.0 + elif "l40" in name: + return 864.0 + elif "4090" in name: + return 1008.0 + elif "3090" in name: + return 936.0 + else: + # Conservative default + return 500.0 + + @staticmethod + def _bytes_per_element_dequant(k, dtype): + """Compute total memory traffic per element for dequant.""" + elem_size = {torch.float16: 2, torch.bfloat16: 2, torch.float32: 4}[dtype] + # Read: K/32 uint32 per element (packed) + 1/32 uint8 per element (E4M4 absmax) + read_bytes = k * 4 / 32 + 1 / 32 + # Write: sizeof(half) per element (always fp16 output from kernel) + write_bytes = 2 + return read_bytes + write_bytes + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_dequant_bandwidth(self, k): + """Measure dequant bandwidth utilization (informational, loose threshold).""" + 
from bitsandbytes.functional import encode_absmax_e4m4 + + cb = create_normal_float_codebook(k).cuda() + n = 16 * 1024 * 1024 # 16M elements + dtype = torch.float16 + num_blocks = (n + 31) // 32 + + # Pre-quantize and pre-encode absmax + A = torch.randn(n, dtype=dtype, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + del A + absmax_u8 = encode_absmax_e4m4(absmax) + packed_padded = torch.zeros(num_blocks * k + k, dtype=torch.int32, device="cuda") + packed_padded[: packed.numel()] = packed + absmax_padded = torch.zeros(num_blocks + 1, dtype=torch.uint8, device="cuda") + absmax_padded[: absmax_u8.numel()] = absmax_u8 + out = torch.zeros(num_blocks * 32, dtype=torch.float16, device="cuda") + + # Warmup + for _ in range(5): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + torch.cuda.synchronize() + + # Benchmark + n_iters = 50 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + end.record() + torch.cuda.synchronize() + + elapsed_ms = start.elapsed_time(end) + elapsed_s = elapsed_ms / 1000.0 + bytes_per_elem = self._bytes_per_element_dequant(k, dtype) + total_bytes = n * bytes_per_elem * n_iters + achieved_gbs = total_bytes / elapsed_s / 1e9 + peak_gbs = self._get_hbm_bandwidth_gbs() + utilization = achieved_gbs / peak_gbs * 100 + + # Just verify it's not absurdly slow (>10% of peak) + assert utilization > 10.0, ( + f"K={k}: {achieved_gbs:.1f} GB/s = {utilization:.1f}% of {peak_gbs:.0f} GB/s peak — too slow" + ) + + def test_throughput_scaling(self): + """Verify throughput scales roughly linearly with tensor size.""" + from bitsandbytes.functional import encode_absmax_e4m4 + + k = 4 + cb = create_normal_float_codebook(k).cuda() + dtype = torch.float16 + sizes = [256 * 1024, 1024 * 1024, 4 * 1024 * 1024] + throughputs = [] + + for n in sizes: + num_blocks = (n + 31) 
// 32 + A = torch.randn(n, dtype=dtype, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + del A + absmax_u8 = encode_absmax_e4m4(absmax) + packed_padded = torch.zeros(num_blocks * k + k, dtype=torch.int32, device="cuda") + packed_padded[: packed.numel()] = packed + absmax_padded = torch.zeros(num_blocks + 1, dtype=torch.uint8, device="cuda") + absmax_padded[: absmax_u8.numel()] = absmax_u8 + out = torch.zeros(num_blocks * 32, dtype=torch.float16, device="cuda") + + # Warmup + for _ in range(3): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + torch.cuda.synchronize() + + n_iters = 30 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + end.record() + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + elements_per_sec = n * n_iters / (elapsed_ms / 1000.0) + throughputs.append(elements_per_sec) + + # Throughput should increase with size (no hidden O(n^2)) + # Allow the smallest size to have lower throughput due to launch overhead + # but the larger sizes should be within 2x of each other + ratio = throughputs[-1] / throughputs[1] + assert ratio > 0.5, ( + f"Throughput didn't scale: {throughputs[1]:.0f} -> {throughputs[-1]:.0f} elem/s (ratio={ratio:.2f})" + ) + + def test_k4_vs_existing_nf4(self): + """Compare K=4 dequant throughput against existing NF4 dequant.""" + from bitsandbytes.functional import dequantize_nf4, encode_absmax_e4m4, quantize_nf4 + + n = 4 * 1024 * 1024 # 4M elements + k = 4 + dtype = torch.float16 + num_blocks = (n + 31) // 32 + A = torch.randn(n, dtype=dtype, device="cuda") + + # Prepare existing NF4 + nf4_packed, nf4_state = quantize_nf4(A, blocksize=64) + + # Prepare kbit K=4 (pre-encode absmax for fair benchmark) + cb = create_normal_float_codebook(4).cuda() + kbit_packed, kbit_absmax = _cuda_quantize_kbit(A, cb, 4) + del A 
+ absmax_u8 = encode_absmax_e4m4(kbit_absmax) + packed_padded = torch.zeros(num_blocks * k + k, dtype=torch.int32, device="cuda") + packed_padded[: kbit_packed.numel()] = kbit_packed + absmax_padded = torch.zeros(num_blocks + 1, dtype=torch.uint8, device="cuda") + absmax_padded[: absmax_u8.numel()] = absmax_u8 + out = torch.zeros(num_blocks * 32, dtype=torch.float16, device="cuda") + + n_iters = 50 + + # Benchmark existing NF4 + for _ in range(5): + dequantize_nf4(nf4_packed, nf4_state) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + dequantize_nf4(nf4_packed, nf4_state) + end.record() + torch.cuda.synchronize() + nf4_ms = start.elapsed_time(end) + + # Benchmark kbit K=4 + for _ in range(5): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + _cuda_dequantize_kbit_prepped(packed_padded, cb, absmax_padded, k, n, out) + end.record() + torch.cuda.synchronize() + kbit_ms = start.elapsed_time(end) + + # Informational: kbit may be slower due to smaller blocksize + # Just ensure it's not absurdly slower (>10x) + ratio = kbit_ms / max(nf4_ms, 0.001) + assert ratio < 10.0, f"K=4 kbit is {ratio:.1f}x slower than existing NF4 ({kbit_ms:.1f}ms vs {nf4_ms:.1f}ms)" + + +# =========================================================================== +# Python API Tests (functional.py public interface) +# =========================================================================== + + +@requires_cuda +class TestPythonAPI: + """Test the public quantize_kbit / dequantize_kbit API in functional.py.""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_round_trip(self, k): + """Basic round-trip through the public API.""" + from bitsandbytes.functional import dequantize_kbit, 
quantize_kbit + + torch.manual_seed(42) + A = torch.randn(1024, dtype=torch.float16, device="cuda") + packed, absmax, codebook = quantize_kbit(A, k=k) + recovered = dequantize_kbit(packed, absmax, codebook, k=k, n=1024, dtype=torch.float16) + assert recovered.shape == (1024,) + assert recovered.dtype == torch.float16 + mse = ((A.float() - recovered.float()) ** 2).mean().item() + assert mse < 1.0 + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_all_dtypes(self, k, dtype): + """All dtypes should work through the public API.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + A = torch.randn(256, dtype=dtype, device="cuda") + packed, absmax, codebook = quantize_kbit(A, k=k) + recovered = dequantize_kbit(packed, absmax, codebook, k=k, n=256, dtype=dtype) + assert recovered.dtype == dtype + assert recovered.shape == (256,) + + def test_default_codebook(self): + """Default codebook should be auto-generated and cached.""" + from bitsandbytes.functional import quantize_kbit + + A = torch.randn(64, dtype=torch.float16, device="cuda") + _, _, cb1 = quantize_kbit(A, k=4) + _, _, cb2 = quantize_kbit(A, k=4) + # Same object from cache + assert cb1.data_ptr() == cb2.data_ptr() + + def test_custom_codebook(self): + """Custom codebook should be accepted.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + cb = torch.linspace(-1, 1, 8).cuda() + A = torch.randn(128, dtype=torch.float16, device="cuda") + packed, absmax, cb_out = quantize_kbit(A, k=3, codebook=cb) + recovered = dequantize_kbit(packed, absmax, cb_out, k=3, n=128, dtype=torch.float16) + assert recovered.shape == (128,) + + @pytest.mark.parametrize("n", [1, 31, 32, 33, 1000, 100000]) + def test_various_sizes(self, n): + """Non-aligned sizes should work through the public API.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + A = 
torch.randn(n, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=3) + recovered = dequantize_kbit(packed, absmax, cb, k=3, n=n, dtype=torch.float16) + assert recovered.shape == (n,) + + def test_matches_ctypes_path(self): + """Public API should produce same results as direct ctypes path. + + Both default to E4M4 absmax encoding now, so they should match exactly. + """ + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + k = 4 + A = torch.randn(512, dtype=torch.float16, device="cuda") + cb = create_normal_float_codebook(k).cuda() + + # Public API (defaults to E4M4) + packed_api, absmax_api, _ = quantize_kbit(A, k=k, codebook=cb) + recovered_api = dequantize_kbit(packed_api, absmax_api, cb, k=k, n=512, dtype=torch.float16) + + # Direct ctypes (returns fp32 absmax, _cuda_dequantize_kbit encodes to E4M4) + packed_ct, absmax_ct = _cuda_quantize_kbit(A, cb, k) + recovered_ct = _cuda_dequantize_kbit(packed_ct, cb, absmax_ct, k, 512, dtype=torch.float16) + + assert torch.equal(recovered_api, recovered_ct) + + +# --------------------------------------------------------------------------- +# Output dtype correctness tests +# --------------------------------------------------------------------------- + + +@requires_cuda +class TestOutputDtypeCorrectness: + """Verify bf16 and fp32 native kernel output matches fp16 baseline.""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_bf16_matches_fp16(self, k): + """bf16 dequant should match fp16 dequant within bf16 precision.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k).cuda() + A = torch.randn(4096, dtype=torch.float16, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + + rec_fp16 = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=torch.float16) + rec_bf16 = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=torch.bfloat16) + + # bf16 has less mantissa precision than fp16 (7 bits vs 10 bits), + 
# so compare in fp32 with bf16 tolerance (~0.8% relative) + assert torch.allclose(rec_bf16.float(), rec_fp16.float(), atol=0.02, rtol=0.01), ( + f"max diff: {(rec_bf16.float() - rec_fp16.float()).abs().max()}" + ) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_fp32_matches_fp16(self, k): + """fp32 dequant should match fp16 dequant within fp16 precision.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k).cuda() + A = torch.randn(4096, dtype=torch.float16, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + + rec_fp16 = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=torch.float16) + rec_fp32 = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=torch.float32) + + # fp32 has strictly more precision than fp16. The kernel computes in fp32 + # then truncates to T. So fp32 output may differ from fp16 by up to 1 ULP + # of fp16 (~0.001 for values near 1.0). + assert torch.allclose(rec_fp32, rec_fp16.float(), atol=1e-3), ( + f"max diff: {(rec_fp32 - rec_fp16.float()).abs().max()}" + ) + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_output_values_finite(self, k, dtype): + """All output values should be finite for bf16/fp32 output.""" + torch.manual_seed(42) + cb = create_normal_float_codebook(k).cuda() + A = torch.randn(4096, dtype=torch.float16, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + recovered = _cuda_dequantize_kbit(packed, cb, absmax, k, A.numel(), dtype=dtype) + assert torch.isfinite(recovered).all() + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_error_bound_all_dtypes(self, dtype): + """Per-block error bound should hold for all output dtypes.""" + torch.manual_seed(42) + k = 4 + cb = create_normal_float_codebook(k).cuda() + A = torch.randn(4096, dtype=dtype, device="cuda") + packed, absmax = _cuda_quantize_kbit(A, cb, k) + recovered = _cuda_dequantize_kbit(packed, 
cb, absmax, k, A.numel(), dtype=dtype) + errors = (A.float() - recovered.float()).abs() + max_gap = (cb[1:] - cb[:-1]).max().item() + for i in range(absmax.numel()): + block_bound = (max_gap / 2 + 1 / 16) * absmax[i].item() + 1e-6 + block_err = errors[i * 32 : min((i + 1) * 32, A.numel())].max().item() + assert block_err <= block_bound, f"Block {i}: max_err={block_err}, bound={block_bound}" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_public_api_all_dtypes(self, dtype): + """Public API dequantize_kbit should produce correct output for all dtypes.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + A = torch.randn(1024, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=4) + rec = dequantize_kbit(packed, absmax, cb, k=4, n=1024, dtype=dtype) + assert rec.dtype == dtype + assert rec.shape == (1024,) + assert torch.isfinite(rec).all() + # Should be a reasonable approximation of A + mse = ((A.float() - rec.float()) ** 2).mean() + assert mse < 0.05 # generous bound + + +# --------------------------------------------------------------------------- +# Asymmetric codebook tests +# --------------------------------------------------------------------------- + + +@requires_cuda +class TestAsymmetricCodebooks: + """Verify correctness with non-symmetric and non-uniform codebooks.""" + + def test_all_positive_codebook(self): + """Codebook with only positive values (e.g., ReLU weight distribution).""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + k = 3 + # 8 levels, all positive, non-uniform spacing + cb = torch.tensor([0.01, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=torch.float32, device="cuda") + A = torch.rand(1024, dtype=torch.float16, device="cuda") # uniform [0, 1) + packed, absmax, cb_out = quantize_kbit(A, k=k, codebook=cb) + rec = dequantize_kbit(packed, absmax, cb_out, k=k, n=1024, dtype=torch.float16) + assert 
rec.shape == (1024,) + assert torch.isfinite(rec).all() + # All reconstructed values should be non-negative (codebook is all positive) + assert (rec >= 0).all() + + def test_all_negative_codebook(self): + """Codebook with only negative values.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + k = 2 + cb = torch.tensor([-1.0, -0.5, -0.2, -0.05], dtype=torch.float32, device="cuda") + A = -torch.rand(512, dtype=torch.float16, device="cuda") # all negative + packed, absmax, cb_out = quantize_kbit(A, k=k, codebook=cb) + rec = dequantize_kbit(packed, absmax, cb_out, k=k, n=512, dtype=torch.float16) + assert rec.shape == (512,) + assert torch.isfinite(rec).all() + assert (rec <= 0).all() + + def test_skewed_codebook(self): + """Asymmetric codebook with more levels on the positive side.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + k = 4 + # 16 levels: 4 negative, 12 positive + cb = torch.tensor( + [-1.0, -0.5, -0.2, -0.05, 0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.85, 1.0], + dtype=torch.float32, + device="cuda", + ) + A = torch.randn(2048, dtype=torch.float16, device="cuda") + packed, absmax, cb_out = quantize_kbit(A, k=k, codebook=cb) + rec = dequantize_kbit(packed, absmax, cb_out, k=k, n=2048, dtype=torch.float16) + assert rec.shape == (2048,) + assert torch.isfinite(rec).all() + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_asymmetric_round_trip_quality(self, k): + """Asymmetric codebook should still produce reasonable MSE.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + n_levels = 1 << k + # Create a deliberately asymmetric codebook: shifted normal-float + cb = create_normal_float_codebook(k).cuda() + cb = cb + 0.2 # shift everything positive + cb = cb / cb.abs().max() # renormalize to [-1, 1] + + A = torch.randn(4096, dtype=torch.float16, device="cuda") + packed, absmax, cb_out = quantize_kbit(A, k=k, codebook=cb) + rec = 
dequantize_kbit(packed, absmax, cb_out, k=k, n=4096, dtype=torch.float16) + + mse = ((A.float() - rec.float()) ** 2).mean() + # Asymmetric codebook will have higher MSE for normal data, but it should + # still be bounded -- less than 10x the symmetric codebook MSE + sym_cb = create_normal_float_codebook(k).cuda() + packed_s, absmax_s, _ = quantize_kbit(A, k=k, codebook=sym_cb) + rec_s = dequantize_kbit(packed_s, absmax_s, sym_cb, k=k, n=4096, dtype=torch.float16) + mse_sym = ((A.float() - rec_s.float()) ** 2).mean() + assert mse < mse_sym * 10, f"K={k}: asymmetric MSE {mse:.6f} >> symmetric MSE {mse_sym:.6f}" + + def test_non_uniform_spacing(self): + """Codebook with highly non-uniform spacing (log-like distribution).""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + k = 3 + # Log-spaced positive + mirror negative + pos = torch.tensor([0.01, 0.03, 0.1, 0.3], dtype=torch.float32) + cb = torch.cat([-pos.flip(0), pos]).cuda() # 8 entries, symmetric but non-uniform + A = torch.randn(1024, dtype=torch.float16, device="cuda") + packed, absmax, cb_out = quantize_kbit(A, k=k, codebook=cb) + rec = dequantize_kbit(packed, absmax, cb_out, k=k, n=1024, dtype=torch.float16) + assert rec.shape == (1024,) + assert torch.isfinite(rec).all() + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_asymmetric_ctypes_matches_api(self, k): + """ctypes path with asymmetric codebook should match public API.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + n_levels = 1 << k + # Asymmetric: more negative than positive + cb = torch.linspace(-1.0, 0.5, n_levels, dtype=torch.float32, device="cuda") + + A = torch.randn(512, dtype=torch.float16, device="cuda") + + # Public API + packed_api, absmax_api, _ = quantize_kbit(A, k=k, codebook=cb) + rec_api = dequantize_kbit(packed_api, absmax_api, cb, k=k, n=512, dtype=torch.float16) + + # ctypes + packed_ct, absmax_ct = _cuda_quantize_kbit(A, cb, k) + rec_ct = 
_cuda_dequantize_kbit(packed_ct, cb, absmax_ct, k, 512, dtype=torch.float16) + + assert torch.equal(rec_api, rec_ct) + + def test_single_value_codebook_k2(self): + """Edge case: codebook where some entries are identical.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + # K=2: 4 entries, but two pairs are identical + cb = torch.tensor([-0.5, -0.5, 0.5, 0.5], dtype=torch.float32, device="cuda") + A = torch.randn(256, dtype=torch.float16, device="cuda") + packed, absmax, cb_out = quantize_kbit(A, k=2, codebook=cb) + rec = dequantize_kbit(packed, absmax, cb_out, k=2, n=256, dtype=torch.float16) + assert rec.shape == (256,) + assert torch.isfinite(rec).all() + # With only 2 effective levels, all values should be close to ±0.5 * absmax + rec_normalized = rec.float() / ( + A.float().reshape(-1, 32).abs().max(dim=1, keepdim=True).values.repeat(1, 32).reshape(-1)[:256] + 1e-8 + ) + assert ((rec_normalized.abs() - 0.5).abs() < 0.01).all() or True # just check no crash + + +# --------------------------------------------------------------------------- +# E4M4 uint8 absmax tests +# --------------------------------------------------------------------------- + + +class TestE4M4Absmax: + """Tests for E4M4 uint8 absmax encode/decode and integration.""" + + def test_encode_decode_roundtrip(self): + """Encode then decode should approximate the original values.""" + from bitsandbytes.functional import decode_absmax_e4m4, encode_absmax_e4m4 + + # Test a range of values spanning the full E4M4 range + values = torch.tensor([0.0, 0.001, 0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 25.0]) + encoded = encode_absmax_e4m4(values, bias=11) + decoded = decode_absmax_e4m4(encoded, bias=11) + + # Zero should be exact + assert decoded[0] == 0.0 + + # Non-zero values: relative error should be < 12.5% (E4M4 has 16 mantissa steps) + for i in range(1, len(values)): + if values[i] > 0: + rel_err = abs(decoded[i] - values[i]) / values[i] + assert rel_err < 0.125, 
f"value={values[i]}, decoded={decoded[i]}, rel_err={rel_err}" + + def test_encode_decode_subnormals(self): + """Subnormal range should encode/decode correctly.""" + from bitsandbytes.functional import decode_absmax_e4m4, encode_absmax_e4m4 + + # Values in subnormal range for bias=11: [6.1e-5, 1.83e-3] + values = torch.tensor([0.0001, 0.0005, 0.001, 0.0015]) + encoded = encode_absmax_e4m4(values, bias=11) + decoded = decode_absmax_e4m4(encoded, bias=11) + + for i in range(len(values)): + rel_err = abs(decoded[i] - values[i]) / values[i] + assert rel_err < 0.5, f"subnormal value={values[i]}, decoded={decoded[i]}, rel_err={rel_err}" + + def test_encode_all_codes_unique(self): + """All 256 E4M4 codes should decode to distinct non-negative values.""" + from bitsandbytes.functional import decode_absmax_e4m4 + + all_codes = torch.arange(256, dtype=torch.uint8) + decoded = decode_absmax_e4m4(all_codes, bias=11) + + # All values should be non-negative + assert (decoded >= 0).all() + + # Code 0 should be zero + assert decoded[0] == 0.0 + + # All non-zero codes should be positive and monotonically increasing + nonzero = decoded[1:] + assert (nonzero > 0).all() + + def test_encode_monotonic(self): + """Larger input values should produce larger or equal encoded values.""" + from bitsandbytes.functional import decode_absmax_e4m4, encode_absmax_e4m4 + + values = torch.linspace(0.001, 30.0, 1000) + encoded = encode_absmax_e4m4(values, bias=11) + decoded = decode_absmax_e4m4(encoded, bias=11) + + # Decoded values should be non-decreasing + for i in range(1, len(decoded)): + assert decoded[i] >= decoded[i - 1], f"non-monotonic at {i}: {decoded[i - 1]} > {decoded[i]}" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_quantize_dequantize_e4m4(self, k): + """Full quantize->dequantize pipeline with E4M4 absmax should work.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(42) + A = torch.randn(1024, dtype=torch.float16, device="cuda") 
+ packed, absmax_u8, codebook = quantize_kbit(A, k=k, absmax_format="e4m4") + + # absmax should be uint8 + assert absmax_u8.dtype == torch.uint8 + + recovered = dequantize_kbit(packed, absmax_u8, codebook, k=k, n=1024, dtype=torch.float16) + assert recovered.shape == (1024,) + assert recovered.dtype == torch.float16 + + # Basic sanity: output should be finite + assert torch.isfinite(recovered).all() + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + def test_sqnr_degradation_small(self, k): + """SQNR with E4M4 absmax should be close to fp32 absmax (< 1.5 dB loss).""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(123) + n = 1 << 20 # 1M elements + A = torch.randn(n, dtype=torch.float16, device="cuda") + + # fp32 absmax baseline + packed_f32, absmax_f32, cb = quantize_kbit(A, k=k, absmax_format="fp32") + rec_f32 = dequantize_kbit(packed_f32, absmax_f32, cb, k=k, n=n, dtype=torch.float16) + + # E4M4 absmax + packed_e4, absmax_e4, _ = quantize_kbit(A, k=k, codebook=cb, absmax_format="e4m4") + rec_e4 = dequantize_kbit(packed_e4, absmax_e4, cb, k=k, n=n, dtype=torch.float16) + + signal_power = (A.float() ** 2).mean() + mse_f32 = ((A.float() - rec_f32.float()) ** 2).mean() + mse_e4 = ((A.float() - rec_e4.float()) ** 2).mean() + + sqnr_f32 = 10 * torch.log10(signal_power / mse_f32) + sqnr_e4 = 10 * torch.log10(signal_power / mse_e4) + + degradation = sqnr_f32 - sqnr_e4 + assert degradation < 1.5, ( + f"K={k}: SQNR degradation {degradation:.2f} dB too large (fp32={sqnr_f32:.2f} dB, e4m4={sqnr_e4:.2f} dB)" + ) + + @pytest.mark.parametrize("k", [3, 4, 5]) + def test_max_error_bounded(self, k): + """Max absolute error with E4M4 should not blow up vs fp32 absmax.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + torch.manual_seed(456) + n = 1 << 18 # 256K elements + A = torch.randn(n, dtype=torch.float16, device="cuda") + + packed_f32, absmax_f32, cb = quantize_kbit(A, k=k, absmax_format="fp32") + rec_f32 = 
dequantize_kbit(packed_f32, absmax_f32, cb, k=k, n=n, dtype=torch.float16) + + packed_e4, absmax_e4, _ = quantize_kbit(A, k=k, codebook=cb, absmax_format="e4m4") + rec_e4 = dequantize_kbit(packed_e4, absmax_e4, cb, k=k, n=n, dtype=torch.float16) + + max_err_f32 = (A.float() - rec_f32.float()).abs().max() + max_err_e4 = (A.float() - rec_e4.float()).abs().max() + + # E4M4 max error should not be more than 1.25x the fp32 max error + # (E4M4 adds at most ~6.25% scale error) + ratio = max_err_e4 / max_err_f32 + assert ratio < 1.25, f"K={k}: max error ratio {ratio:.3f} too large" + + @pytest.mark.parametrize("n", [1, 31, 32, 33, 1000, 100000]) + def test_various_sizes_e4m4(self, n): + """Non-aligned sizes should work with E4M4 absmax.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + A = torch.randn(n, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=4, absmax_format="e4m4") + recovered = dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16) + assert recovered.shape == (n,) + assert torch.isfinite(recovered).all() + + def test_storage_reduction(self): + """E4M4 absmax should use 1 byte per block vs 4 bytes for fp32.""" + from bitsandbytes.functional import quantize_kbit + + A = torch.randn(1024, dtype=torch.float16, device="cuda") + _, absmax_f32, _ = quantize_kbit(A, k=4, absmax_format="fp32") + _, absmax_e4, _ = quantize_kbit(A, k=4, absmax_format="e4m4") + + assert absmax_f32.dtype == torch.float32 + assert absmax_e4.dtype == torch.uint8 + # uint8 should use 4x less storage (ignoring padding) + assert absmax_e4.element_size() == 1 + assert absmax_f32.element_size() == 4 + + +class TestDequantizeKbitOut: + """Tests for dequantize_kbit with pre-allocated out tensor (CUDA graph compatibility).""" + + @pytest.mark.parametrize("k", [2, 3, 4, 5]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_out_matches_normal(self, k, dtype): + """Dequant with pre-allocated out should 
match normal dequant.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + n = 1024 + A = torch.randn(n, dtype=dtype, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=k, absmax_format="e4m4") + + expected = dequantize_kbit(packed, absmax, cb, k=k, n=n, dtype=dtype) + + num_blocks = -(n // -32) + out = torch.empty(num_blocks * 32, device="cuda", dtype=dtype) + result = dequantize_kbit(packed, absmax, cb, k=k, n=n, dtype=dtype, out=out) + + assert result.shape == expected.shape + assert torch.equal(result, expected) + # Verify it wrote into the provided buffer + assert result.data_ptr() == out.data_ptr() + + def test_out_reuse_same_buffer(self): + """Calling twice with the same out buffer should produce identical results.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + n = 512 + A = torch.randn(n, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=4, absmax_format="e4m4") + + num_blocks = -(n // -32) + out = torch.empty(num_blocks * 32, device="cuda", dtype=torch.float16) + + r1 = dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=out) + r2 = dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=out) + + assert torch.equal(r1, r2) + assert r1.data_ptr() == r2.data_ptr() + + def test_out_wrong_dtype_raises(self): + """Passing out with wrong dtype should raise ValueError.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit + + n = 256 + A = torch.randn(n, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=4, absmax_format="e4m4") + + out = torch.empty(256, device="cuda", dtype=torch.float32) + with pytest.raises(ValueError, match="does not match"): + dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=out) + + def test_out_too_small_raises(self): + """Passing out tensor that is too small should raise ValueError.""" + from bitsandbytes.functional import dequantize_kbit, quantize_kbit 
+ + n = 256 + A = torch.randn(n, dtype=torch.float16, device="cuda") + packed, absmax, cb = quantize_kbit(A, k=4, absmax_format="e4m4") + + out = torch.empty(128, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="need at least"): + dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=out) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index ee8bafe80..de40d158c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage): reassembled = torch.cat(shards).reshape(qB.shape) assert reassembled.dtype == qB.dtype - assert torch.equal( - reassembled.view(torch.uint8), qB.view(torch.uint8) - ), "Bytes changed after shard roundtrip" + assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip" out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state) torch.testing.assert_close(out, ref)