Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import logging
from math import prod
from typing import Optional
import warnings
Expand All @@ -8,6 +9,8 @@

import bitsandbytes.functional as F

logger = logging.getLogger(__name__)

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

Expand Down Expand Up @@ -123,7 +126,7 @@ def forward(

# Cast A to fp16
if A.dtype != torch.float16 and not _is_compiling():
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
logger.warning("MatMul8bitLt: inputs will be cast from %s to float16 during quantization", A.dtype)

if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
Expand Down
17 changes: 8 additions & 9 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import Any, Optional, TypeVar, Union, overload
import warnings

import torch
from torch import Tensor, device, dtype, nn
Expand All @@ -20,6 +20,8 @@
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

logger = logging.getLogger(__name__)

T = TypeVar("T", bound="torch.nn.Module")


Expand Down Expand Up @@ -443,7 +445,7 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line
return

if getattr(module, "quant_state", None) is None:
warnings.warn(
logger.warning(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)

Expand Down Expand Up @@ -536,15 +538,13 @@ def set_compute_type(self, x):
if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn(
logger.warning(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
)
warnings.filterwarnings("ignore", message=".*inference.")
if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
warnings.warn(
logger.warning(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
)
warnings.filterwarnings("ignore", message=".*inference or training")

def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
Expand Down Expand Up @@ -877,7 +877,7 @@ def __init__(
blocksize = self.weight.blocksize

if embedding_dim % blocksize != 0:
warnings.warn(
logger.warning(
f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
"This will lead to slow inference.",
)
Expand Down Expand Up @@ -1164,9 +1164,8 @@ def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
logger.warning("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
outlier_idx = tracer.get_outliers(self.weight)
# print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx

if not self.is_quantized:
Expand Down
5 changes: 4 additions & 1 deletion bitsandbytes/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import logging
import shlex
import subprocess

import torch

logger = logging.getLogger(__name__)


def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
Expand Down Expand Up @@ -65,7 +68,7 @@ def get_hvalue(self, weight):

def get_outliers(self, weight):
if not self.is_initialized():
print("Outlier tracer is not initialized...")
logger.warning("Outlier tracer is not initialized...")
return None
hvalue = self.get_hvalue(weight)
if hvalue in self.hvalue2outlier_idx:
Expand Down
34 changes: 17 additions & 17 deletions tests/test_modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import inspect
import logging

import pytest
import torch
Expand All @@ -8,6 +10,12 @@
from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu


@contextlib.contextmanager
def caplog_at_level(caplog, level, logger_name):
    """Capture log records emitted by *logger_name* at *level* or above.

    Thin wrapper around pytest's ``caplog.at_level`` so tests can use a
    uniform helper; restores the previous log level on exit.
    """
    capture = caplog.at_level(level, logger=logger_name)
    with capture:
        yield


class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
Expand Down Expand Up @@ -453,46 +461,38 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu


@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
def test_4bit_linear_warnings(device, caplog):
dim1 = 64

with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)

with pytest.warns(UserWarning) as record:
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
assert any("inference or training" in msg for msg in caplog.messages)

caplog.clear()
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)

assert len(record) == 2
assert any("inference." in msg for msg in caplog.messages)


@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device):
def test_4bit_embedding_warnings(device, caplog):
num_embeddings = 128
default_block_size = 64

with pytest.warns(UserWarning, match=r"inference."):
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
net = bnb.nn.Embedding4bit(
num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
)
net.to(device)
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
net(inp)
assert any("inference" in msg for msg in caplog.messages)


def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
Expand Down