diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index da168e17b..d9e519ac2 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -314,7 +314,14 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
-        # 3. Save state
+        # 3. Write to out tensor if provided
+        if out is not None:
+            out.copy_(output)
+            # `out` is a forward() input mutated in place; autograd requires such inputs to be marked dirty.
+            ctx.mark_dirty(out)
+            output = out
+
+        # 4. Save state
         ctx.state = quant_state
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 7134925c1..b1c087519 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -262,3 +262,45 @@ def test_matmul_4bit(
 
                 if req_grad[2]:
                     torch.testing.assert_close(gradBias1, gradBias2)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"], ids=id_formatter("quant_type"))
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
+def test_matmul_4bit_out_parameter(device, quant_type, dtype, has_bias):
+    """Test that matmul_4bit(A, B, out=output) writes the result into output (issue #1235)."""
+    M, K, N = 32, 64, 48
+
+    # Create weight matrix (K, N) and quantize — matmul_4bit computes A @ dequant(B)
+    W = torch.empty(K, N, device=device, dtype=dtype)
+    torch.nn.init.xavier_uniform_(W)
+    B_quant, quant_state = bnb.functional.quantize_4bit(W, quant_type=quant_type)
+
+    bias = None
+    if has_bias:
+        bias = torch.randn(N, device=device, dtype=dtype)
+
+    # --- Test 2D input (matrix path through MatMul4Bit) ---
+    A_2d = torch.randn(M, K, device=device, dtype=dtype)
+    expected = bnb.matmul_4bit(A_2d, B_quant, quant_state, bias=bias)
+
+    out_2d = torch.zeros(M, N, device=device, dtype=dtype)
+    returned = bnb.matmul_4bit(A_2d, B_quant, quant_state, out=out_2d, bias=bias)
+
+    # out tensor should contain the result
+    torch.testing.assert_close(out_2d, expected)
+    # returned tensor must alias out's storage (not a freshly allocated result)
+    assert returned.data_ptr() == out_2d.data_ptr(), "returned tensor should share storage with out"
+
+    # --- Test 1D input (gemv path) if on CUDA and blocksize divides K ---
+    # Skip bias for 1D: the gemv path has a pre-existing shape bug with bias when K != N.
+    if device == "cuda" and K % quant_state.blocksize == 0 and not has_bias:
+        A_1d = torch.randn(K, device=device, dtype=dtype)
+        expected_1d = bnb.matmul_4bit(A_1d, B_quant, quant_state)
+
+        out_1d = torch.zeros_like(expected_1d)
+        returned_1d = bnb.matmul_4bit(A_1d, B_quant, quant_state, out=out_1d)
+
+        torch.testing.assert_close(out_1d, expected_1d)
+        assert returned_1d.data_ptr() == out_1d.data_ptr(), "returned tensor should share storage with out"
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index ee8bafe80..de40d158c 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
     reassembled = torch.cat(shards).reshape(qB.shape)
     assert reassembled.dtype == qB.dtype
 
-    assert torch.equal(
-        reassembled.view(torch.uint8), qB.view(torch.uint8)
-    ), "Bytes changed after shard roundtrip"
+    assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"
 
     out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
     torch.testing.assert_close(out, ref)