Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

# 3. Save state
# 3. Write to out tensor if provided
if out is not None:
out.copy_(output)
output = out

# 4. Save state
ctx.state = quant_state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

Expand Down
42 changes: 42 additions & 0 deletions tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,45 @@ def test_matmul_4bit(

if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"], ids=id_formatter("quant_type"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmul_4bit_out_parameter(device, quant_type, dtype, has_bias):
    """Regression test for issue #1235: matmul_4bit must honor its ``out=`` argument.

    Checks two things for both the 2D (matrix) and 1D (gemv) code paths:
    the caller-provided tensor receives the computed result, and the returned
    tensor aliases that same storage.
    """
    M, K, N = 32, 64, 48

    # matmul_4bit computes A @ dequant(B); quantize a (K, N) weight matrix.
    weight = torch.randn(K, N, device=device, dtype=dtype)
    torch.nn.init.xavier_uniform_(weight)
    qweight, quant_state = bnb.functional.quantize_4bit(weight, quant_type=quant_type)

    bias = torch.randn(N, device=device, dtype=dtype) if has_bias else None

    # --- 2D input: exercises the matrix path through MatMul4Bit ---
    inp = torch.randn(M, K, device=device, dtype=dtype)
    reference = bnb.matmul_4bit(inp, qweight, quant_state, bias=bias)

    dest = torch.zeros(M, N, device=device, dtype=dtype)
    result = bnb.matmul_4bit(inp, qweight, quant_state, out=dest, bias=bias)

    # The out tensor must hold the product...
    torch.testing.assert_close(dest, reference)
    # ...and the return value must be that same tensor, not a fresh allocation.
    assert result.data_ptr() == dest.data_ptr(), "returned tensor should share storage with out"

    # --- 1D input: exercises the gemv path (CUDA only, blocksize must divide K) ---
    # Bias is skipped for 1D: the gemv path has a pre-existing shape bug with bias when K != N.
    if device == "cuda" and K % quant_state.blocksize == 0 and not has_bias:
        vec = torch.randn(K, device=device, dtype=dtype)
        reference_1d = bnb.matmul_4bit(vec, qweight, quant_state)

        dest_1d = torch.zeros_like(reference_1d)
        result_1d = bnb.matmul_4bit(vec, qweight, quant_state, out=dest_1d)

        torch.testing.assert_close(dest_1d, reference_1d)
        assert result_1d.data_ptr() == dest_1d.data_ptr(), "returned tensor should share storage with out"
4 changes: 1 addition & 3 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
reassembled = torch.cat(shards).reshape(qB.shape)

assert reassembled.dtype == qB.dtype
assert torch.equal(
reassembled.view(torch.uint8), qB.view(torch.uint8)
), "Bytes changed after shard roundtrip"
assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"

out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
torch.testing.assert_close(out, ref)
Expand Down