1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -312,6 +312,7 @@ if(BUILD_CUDA)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
endif()
if(BUILD_HIP)
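(Presumably, with `CUDA_SEPARABLE_COMPILATION ON` already set, the new `CUDA_RESOLVE_DEVICE_SYMBOLS ON` line makes CMake run the device-link step on the `bitsandbytes` shared library itself, so the relocatable device code introduced by the new kbit kernels is resolved there rather than deferred to a consuming target that never exists for a ctypes-loaded library.)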
1,145 changes: 1,145 additions & 0 deletions agents/flute_kernel_guide.md

Large diffs are not rendered by default.

1,391 changes: 1,391 additions & 0 deletions agents/kbit_gemm_context.md

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions bitsandbytes/_ops.py
@@ -431,3 +431,74 @@ def _(
        qmap2.dtype == absmax2.dtype == torch.float32,
        lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
    )


# K-bit blockwise quantization (K=2..5, blocksize=32)

torch.library.define(
    "bitsandbytes::quantize_kbit",
    "(Tensor A, Tensor codebook, int k) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_kbit")
def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}")
    n = A.numel()
    num_blocks = -(n // -32)  # ceil(n / 32)
    # packed: num_blocks * k int32 words + k padding words
    packed = torch.empty(num_blocks * k + k, device=A.device, dtype=torch.int32)
    absmax = torch.empty(num_blocks + 1, device=A.device, dtype=torch.float32)
    return packed, absmax


torch.library.define(
    "bitsandbytes::dequantize_kbit",
    "(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_kbit")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        absmax.dtype in (torch.float32, torch.float16, torch.uint8),
        lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}",
    )
    num_blocks = -(n // -32)
    # Output is padded to a whole number of 32-element blocks.
    return torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)


torch.library.define(
    "bitsandbytes::dequantize_kbit_",
    "(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)",
)


@register_fake("bitsandbytes::dequantize_kbit_")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        absmax.dtype in (torch.float32, torch.float16, torch.uint8),
        lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}",
    )
    num_blocks = -(n // -32)
    torch._check(out.numel() >= num_blocks * 32, lambda: f"out must have at least {num_blocks * 32} elements")
    torch._check(out.dtype == dtype, lambda: f"out dtype {out.dtype} must match requested dtype {dtype}")
    return out
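
A quick illustration (not part of the PR) of what the fake kernels above buy: because `register_fake` supplies the meta implementation, the ops can be called on `"meta"` tensors to check output shapes without a GPU. The sizes below simply restate the shape math in the fake `quantize_kbit`, assuming importing `bitsandbytes` registers the ops.

```python
import torch
import bitsandbytes  # noqa: F401  # importing registers the bitsandbytes::* ops

k = 4
A = torch.empty(100, device="meta", dtype=torch.float16)
codebook = torch.empty(1 << k, device="meta", dtype=torch.float32)

# Dispatches to the fake kernel registered above; no CUDA required.
packed, absmax = torch.ops.bitsandbytes.quantize_kbit(A, codebook, k)

num_blocks = -(A.numel() // -32)  # ceil(100 / 32) == 4
assert packed.numel() == num_blocks * k + k  # 4 * 4 + 4 == 20 int32 words
assert absmax.numel() == num_blocks + 1  # 5 float32 entries
```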
114 changes: 114 additions & 0 deletions bitsandbytes/backends/cuda/ops.py
@@ -764,3 +764,117 @@ def _optimizer_update_8bit_blockwise_impl(

register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)


# K-bit blockwise quantization (K=2..5, blocksize=32)

_KBIT_DTYPE_SUFFIX = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}


@register_kernel("bitsandbytes::quantize_kbit", "cuda")
def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        A.dtype in _KBIT_DTYPE_SUFFIX,
        lambda: f"quantize_kbit only supports float16/bfloat16/float32, got {A.dtype}",
    )
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}")

    n = A.numel()
    num_blocks = -(n // -32)  # ceil(n / 32)
    packed = torch.zeros(num_blocks * k + k, device=A.device, dtype=torch.int32)
    absmax = torch.zeros(num_blocks + 1, device=A.device, dtype=torch.float32)

    with _cuda_device_of(A):
        tname = _KBIT_DTYPE_SUFFIX[A.dtype]
        fn = getattr(lib, f"cquantize_kbit_{tname}_k{k}")
        fn(
            get_ptr(codebook),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(packed),
            ct.c_int(n),
        )

    return packed, absmax


_KBIT_ABSMAX_SUFFIX = {
    torch.uint8: "u8abs",
    torch.float16: "fp16abs",
}


Review comment (Collaborator, Author): When the user passes fp32 absmax, the dequant dispatch silently encodes it to E4M4 before calling the kernel. This is a lossy conversion the caller may not expect — they passed fp32 precision but get E4M4 precision. Consider either warning or documenting this behavior.
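
One possible shape for the warning suggested above (a reviewer sketch, not code from this PR; it reuses the PR's `encode_absmax_e4m4` helper, and the `_coerce_absmax` name is hypothetical):

```python
import warnings

import torch

from bitsandbytes.functional import encode_absmax_e4m4


def _coerce_absmax(absmax: torch.Tensor) -> torch.Tensor:
    """Return an absmax tensor the kernels accept, warning on lossy narrowing."""
    if absmax.dtype == torch.float32:
        warnings.warn(
            "dequantize_kbit: fp32 absmax is re-encoded to E4M4 (uint8), which is lossy; "
            "pass uint8 or float16 absmax to avoid the conversion.",
            stacklevel=3,
        )
        return encode_absmax_e4m4(absmax)
    return absmax
```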

def _dequantize_kbit_impl(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        dtype in _KBIT_DTYPE_SUFFIX,
        lambda: f"dequantize_kbit only supports float16/bfloat16/float32, got {dtype}",
    )
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(
        absmax.dtype in (torch.float32, torch.float16, torch.uint8),
        lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}",
    )

    # If fp32 absmax, encode to E4M4 first
    if absmax.dtype == torch.float32:
        from bitsandbytes.functional import encode_absmax_e4m4

        absmax = encode_absmax_e4m4(absmax)

    tname = _KBIT_DTYPE_SUFFIX[dtype]
    aname = _KBIT_ABSMAX_SUFFIX[absmax.dtype]

    with _cuda_device_of(packed):
        fn = getattr(lib, f"cdequantize_kbit_{tname}_{aname}_k{k}")
        fn(
            get_ptr(packed),
            get_ptr(codebook),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(n),
            _get_tensor_stream(packed),
        )


@register_kernel("bitsandbytes::dequantize_kbit", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_blocks = -(n // -32)
    out = torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)
    _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out)
    return out


@register_kernel("bitsandbytes::dequantize_kbit_", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out)
    return out
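
For reference, a hypothetical end-to-end round trip through the two CUDA kernels above (illustrative only: the uniform codebook is a stand-in for a real quantile-style codebook, and the fp32 absmax returned by `quantize_kbit` goes through the lossy E4M4 path flagged in the review comment):

```python
import torch
import bitsandbytes  # noqa: F401  # registers the bitsandbytes::* ops

k = 4
A = torch.randn(1000, device="cuda", dtype=torch.float16)
codebook = torch.linspace(-1.0, 1.0, 1 << k, device="cuda", dtype=torch.float32)

packed, absmax = torch.ops.bitsandbytes.quantize_kbit(A, codebook, k)

# Out-of-place dequant returns ceil(n / 32) * 32 elements; slice back to n.
out = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, A.numel(), A.dtype)
A_hat = out[: A.numel()]

# In-place variant writes into a caller-provided buffer of the same padded size.
buf = torch.empty_like(out)
torch.ops.bitsandbytes.dequantize_kbit_(packed, codebook, absmax, k, A.numel(), A.dtype, buf)
```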