fully_shard() for huggingface model: pytorch caches too much GPU memory
Dear Community,
I'm working on fine-tuning the Qwen2-VL model using fully_shard() and wrote a script for it. However, I noticed that GPU memory usage stays high (around 50GB to 60GB) even as I scale up the number of GPUs. Besides, it will run into OOM when I try to fine tune 72B model with 128 GPUs.
I'm wondering if there might be any issues with my code or configuration. I'd really appreciate any insights or suggestions you might have. Thanks in advance!
My code:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, AutoModelForVision2Seq, AutoConfig
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import numpy as np
from PIL import Image
import io
import logging
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer, Qwen2VLVisionBlock
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
fully_shard,
MixedPrecisionPolicy,
)
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# init dist
distributed_backend = "nccl" # gloo for cpu
dist.init_process_group(distributed_backend)
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
# model_name = "Qwen/Qwen2-VL-2B-Instruct"
# revision = "895c3a49bc3fa70a340399125c650a463535e71c"
model_name = "Qwen/Qwen2-VL-7B-Instruct"
revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
# model_name = "Qwen/Qwen2-VL-72B-Instruct"
# revision = "f9b556a74d58e6d9915f73227c21045c87342b42"
dataset_id = "HuggingFaceM4/ChartQA"
processor = Qwen2VLProcessor.from_pretrained(model_name,
revision=revision,
)
# Configuration
class Config:
dataset_id = "HuggingFaceM4/ChartQA"
output_dir = "/tmp_ckpt"
batch_size = 2
num_epochs = 3
learning_rate = 5e-5
max_seq_length = 512
lora_rank = 32
lora_alpha = 64
lora_dropout = 0.1
device = "cuda" if torch.cuda.is_available() else "cpu"
system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""
def format_data(sample):
return [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": sample["image"],
},
{
"type": "text",
"text": sample["query"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["label"][0]}],
},
]
# Training function
def train_model(model, train_loader, optimizer, config):
model.train()
total_steps = len(train_loader) * config.num_epochs
step = 0
scaler = torch.amp.GradScaler("cuda", enabled=True)
for epoch in range(config.num_epochs):
total_loss = 0
for batch_idx, batch in enumerate(train_loader):
inputs, labels = batch
inputs = inputs.to(config.device)
labels = labels.to(config.device)
# Mixed precision training
loss = model(**inputs, labels=labels).loss
loss.backward() # no scaler
optimizer.step()
optimizer.zero_grad()
step += 1
logger.info(f"Epoch {epoch+1}/{config.num_epochs}, Step {step}/{total_steps}, Loss: {loss.item():.4f}")
del loss
# Create a data collator to encode text and image pairs
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(example, tokenize=False) for example in examples
] # Prepare texts for processing
image_inputs = [process_vision_info(example)[0] for example in examples] # Process the images to extract inputs
# Tokenize the texts and process the images
batch = processor(
text=texts, images=image_inputs, return_tensors="pt", padding=True
) # Encode texts and images into tensors
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone() # Clone input IDs for labels
labels[labels == processor.tokenizer.pad_token_id] = -100 # Mask padding tokens in labels
# Ignore the image token index in the loss computation (model specific)
if isinstance(processor, Qwen2VLProcessor): # Check if the processor is Qwen2VLProcessor
image_tokens = [151652, 151653, 151655] # Specific image token IDs for Qwen2VLProcessor
else:
image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] # Convert image token to ID
# Mask image token IDs in the labels
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100 # Mask image token IDs in labels
return batch, labels
# Main function
def main():
config = Config()
# Load model and processor
logger.info("Loading model and processor...")
hf_config = AutoConfig.from_pretrained(
model_name,
revision=revision,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
with torch.device("meta"):
model = AutoModelForVision2Seq.from_config(hf_config, torch_dtype=torch.bfloat16)
mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
output_dtype=torch.bfloat16,
cast_forward_inputs=True)
offload_policy = CPUOffloadPolicy(pin_memory=False)
# apply FSDP2
device_mesh = init_device_mesh("cuda", (world_size,))
for module in model.modules():
if isinstance(module, Qwen2VLDecoderLayer):
fully_shard(module,
mesh=device_mesh,
reshard_after_forward=True,
mp_policy=mp_policy,
# offload_policy=offload_policy,
)
model = fully_shard(model,
mesh=device_mesh,
reshard_after_forward=True,
mp_policy=mp_policy,
# offload_policy=offload_policy,
)
model.to_empty(device='cuda')
model_state_dict = model.state_dict()
model_dir = "/cache/fsdp_test/72B_8_files"
# load qwen2-vl model
dcp.load(
state_dict=model_state_dict,
checkpoint_id=model_dir,
planner=DefaultLoadPlanner(allow_partial_load=True),
)
model = model.to(torch.bfloat16).cuda()
# Load dataset
logger.info("Loading dataset...")
train_dataset, eval_dataset, test_dataset = load_dataset(
config.dataset_id, split=['train[:10%]', 'val[:10%]', 'test[:10%]'])
train_dataset = [format_data(sample) for sample in train_dataset]
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=1,
collate_fn=collate_fn,
shuffle=True,
)
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Create output directory
os.makedirs(config.output_dir, exist_ok=True)
# Train
logger.info("Starting training...")
train_model(model, train_dataloader, optimizer, config)
if __name__ == "__main__":
main()
destroy_process_group()
logger.info("Training completed.")
Running command:
torchrun --nnodes=2 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=4 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=8 --nproc_per_node=8 qwenvl_train_fsdp.py
The following is the screenshot of the result of nvidia-smi:
16 GPU:
32 GPU:
64 GPU:
@awgu @mori360 @fegin @yzhangcs @tianyu-l
@mingdianliu could it be possible that the activations dominate the memory usage under such a setting? Like a 7B model, even if we use float32, then the parameters + gradients + optimizer states is like 112 GB and with 16 GPU, each GPU will get roughly 7GB. If you freeze some modules for fine-funing, this number fewer. Same for 72B model issue, you will have to apply other techniques to reduce the memory consumption from activations, like TP or activation checkpointing.
@mingdianliu could it be possible that the activations dominate the memory usage under such a setting? Like a 7B model, even if we use float32, then the parameters + gradients + optimizer states is like 112 GB and with 16 GPU, each GPU will get roughly 7GB. If you freeze some modules for fine-funing, this number fewer. Same for 72B model issue, you will have to apply other techniques to reduce the memory consumption from activations, like TP or activation checkpointing.
Hi @fegin
Thanks for your follow-up. I found it is due to pytorch cache. The allocated and reserved GPU memory is quite small while the cached GPU memory is even higher than 50GB. I had a shoot on torch.cuda.empty_cache() after each training iteration but the GPU memory cache during each training iteration is also high (~20GB). I wonder if it is a bug of FSDP2. If not, is there any method that can mitigate this issue?
Caching is not an issue because those memory will be reused for other tensor allocation. But this will not cause OOM because when new tensors are created, PyTorch will first find some empty caching memory for the tensors. And only if there is no available caching space, will PyTorch ask CUDA to give more. And if CUDA cannot give enough memory, then OOM will happen.
So, if you are not seeing OOM but only seeing high cache memory, that should not be an issue. You actually are seeing OOM, you can try to export this environment variable PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True. If this doesn't help, you need to reduce the memory usage (reducing batch size, model size, using TP, activation chekcpointing...).