4-bit model used more RAM than bf16 in HF transformers
System Info
Google Colab 80GB A100 GPU Linux transformers-4.57.0.dev0 bitsandbytes 0.48.1
Reproduction
Something very strange happened when I tried to fine-tune Qwen3-VL-30B-A3B on Google Colab: using the 4-bit version actually raised a VRAM OOM error, while the bf16 version runs fine.
If I add quantization_config=bnb_config, it OOMs.
# --- Reproduction, part 1: model loading ---
from transformers import BitsAndBytesConfig
import torch
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# Standard QLoRA-style 4-bit setup: NF4 quantization with double quantization
# and bf16 compute.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4',
)
# processor = AutoProcessor.from_pretrained("google/gemma-3-270m")
# NOTE: quantization_config is commented out below — this is the bf16 baseline
# that fits in memory; re-enabling the kwarg reproduces the OOM described above.
model = AutoModelForImageTextToText.from_pretrained(
"Qwen/Qwen3-VL-30B-A3B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"#, quantization_config=bnb_config
)
from peft import LoraConfig, get_peft_model
# LoRA adapter config: rank-16 adapters (alpha=16, no dropout) applied to the
# attention projection layers only.
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.0,
r=16,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
]
)
from trl import SFTConfig
# TRL SFT training arguments for the fine-tuning run.
# Effective batch size = 4 (per device) * 2 (grad accumulation) = 8.
# bf16 training with non-reentrant gradient checkpointing; eval/save every
# 500 steps; metrics pushed to Weights & Biases.
training_args = SFTConfig(
output_dir="qwen_anti_aesthetics_3b",
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=2,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=None,
optim="adamw_torch_fused",
learning_rate=2e-5,
weight_decay=0.001,
logging_steps=10,
eval_steps=500,
logging_strategy="steps",
eval_strategy="steps",
save_strategy="steps",
save_steps=500,
bf16=True,
warmup_ratio=0.02,
push_to_hub=True,
report_to="wandb",
remove_unused_columns=False,
dataloader_num_workers=12,
dataloader_prefetch_factor=4,
dataloader_pin_memory=True,
completion_only_loss=True,
lr_scheduler_type="cosine",
)
# FIX: SFTTrainer was used without being imported (only SFTConfig is imported
# above), so the script raises NameError before training even starts.
from trl import SFTTrainer

# Build the trainer from the pieces above; `train_ds` / `test_ds` are the
# user's datasets, defined outside this snippet.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)
Expected behavior
The 4-bit version should use less VRAM — or at least not more — than the regular model.
Dear @weathon,
thanks for the helpful issue! Would you be so kind as to quickly downgrade to bitsandbytes 0.47.0 to check whether this is a new regression? There were some refactors in both bitsandbytes and transformers recently (I see you have a development version of transformers installed — it's worth checking the released versions as well).
This would be very helpful, thanks!
System Info
Google Colab 80GB A100 GPU Linux transformers-4.57.0.dev0 bitsandbytes 0.48.1
Reproduction
Something very strange happened when I tried to fine-tune Qwen3-VL-30B-A3B on Google Colab: using the 4-bit version actually raised a VRAM OOM error, while the bf16 version runs fine.
If I add quantization_config=bnb_config, it OOMs.
# Reproduction script, reformatted: the original paste had every statement
# collapsed onto a handful of single lines (the model-load line was even left
# with an unbalanced paren by the inline comment). Statement-for-statement
# identical to the pasted code, including the active `processor = ...` line.
from transformers import BitsAndBytesConfig
import torch
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# 4-bit NF4 config with double quantization and bf16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)

processor = AutoProcessor.from_pretrained("google/gemma-3-270m")

# bf16 baseline load; re-enabling quantization_config reproduces the OOM.
model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-30B-A3B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"#, quantization_config=bnb_config
)

from peft import LoraConfig, get_peft_model

# Rank-16 LoRA on the attention projections.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.0,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

from trl import SFTConfig

# SFT arguments: effective batch size 8, bf16, eval/save every 500 steps.
training_args = SFTConfig(
    output_dir="qwen_anti_aesthetics_3b",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_length=None,
    optim="adamw_torch_fused",
    learning_rate=2e-5,
    weight_decay=0.001,
    logging_steps=10,
    eval_steps=500,
    logging_strategy="steps",
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=500,
    bf16=True,
    warmup_ratio=0.02,
    push_to_hub=True,
    report_to="wandb",
    remove_unused_columns=False,
    dataloader_num_workers=12,
    dataloader_prefetch_factor=4,
    dataloader_pin_memory=True,
    completion_only_loss=True,
    lr_scheduler_type="cosine",
)

trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)
Expected behavior
The 4-bit version should use less VRAM — or at least not more — than the regular model.
I hit the same issue when using qwen3-30b-a3b-2507-it.
torch reports only ~17 GB allocated, but actual GPU usage is about 35 GB:
Model device: cuda:0 CUDA device count: 1 GPU 0 memory allocated: 17241.29 MB
"""Reproduction: 4-bit Qwen3-30B-A3B uses far more GPU memory than reported.

torch.cuda.memory_allocated() shows ~17 GB while actual device usage is ~35 GB,
so the script also prints memory_reserved() — the caching allocator's real
footprint, which is what nvidia-smi reflects.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import time

model_name = "qwen3-30b-a3b"

# Load the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# FIX: passing load_in_4bit=True directly to from_pretrained is deprecated
# and is ignored on recent transformers releases — the model then silently
# loads unquantized (consistent with the ~35 GB observed). Pass an explicit
# BitsAndBytesConfig via quantization_config instead.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # 4-bit quantization to reduce memory
)

# Print how the model was placed and how much memory each GPU is using.
print(f"Model device: {model.device}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        # memory_allocated() undercounts real usage: it excludes blocks the
        # caching allocator holds. memory_reserved() is closer to nvidia-smi.
        print(f"GPU {i} memory allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
        print(f"GPU {i} memory reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
else:
    print("No CUDA device, running on CPU.")

# Prepare the model input.
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Conduct text completion and time it.
start_time = time.time()
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=16384
)
end_time = time.time()
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
content = tokenizer.decode(output_ids, skip_special_tokens=True)
print("content:", content)

# Report generation throughput in tokens per second.
num_tokens = len(output_ids)
elapsed = end_time - start_time
if elapsed > 0:
    print(f"Token per second: {num_tokens / elapsed:.2f}")
else:
    print("Token per second: N/A (elapsed time too short)")
Have you solved it?