[BUG] ZeRO-3 hangs during inference; need to detach part of the computational graph, but .detach()/torch.no_grad() do not work.
Describe the bug I am training a video LLM in which I encode long videos with a varying number of forward passes to avoid OOM issues. I would like to use ZeRO-3, but running a part of the model a different number of times on each rank makes the computational graph differ across nodes/GPUs, and ZeRO-3 hangs (presumably because the parameter-gather collectives no longer match up across ranks).
I don't need gradients for the video encoder and would like to remove it from the computational graph entirely. I have tried (1) freezing the encoder, (2) running the encoder under torch.no_grad(), and (3) calling .detach() on its output tensors, but to no avail.
How can I effectively "remove" a part of the model from the computational graph, if that is at all possible? I can't pre-encode entire videos, as that is too memory-heavy for my setup.
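For reference, attempts (1) and (2) looked roughly like this (a simplified sketch of my actual code, reusing the clip_model attribute from the repro below; frames stands in for a batch of preprocessed video frames):

# Sketch of attempts (1) and (2); simplified from my actual training code
for p in model.clip_model.parameters():
    p.requires_grad = False              # (1) freeze the video encoder

with torch.no_grad():                    # (2) run the encoder without autograd
    encoded = model.clip_model.encode_image(frames)  # frames: preprocessed video frames

Neither of these (nor detaching the outputs afterwards) stops the hang.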
To Reproduce Steps to reproduce the behavior: take any model and apply some part of it multiple times, e.g.:
import torch
import torch.nn as nn
from transformers import Trainer, TrainingArguments
import clip
class VideoToEmbeddingModel(nn.Module):
    def __init__(self, clip_model_name="ViT-B/32", mlp_input_dim=512, mlp_hidden_dim=128, mlp_output_dim=8):
        super().__init__()
        self.clip_model, _ = clip.load(clip_model_name)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_output_dim)
        )

    def forward(self, video_frames):
        # Initialize list to store encoded frames
        encoded_frames = []
        # Loop through video frames
        for frame in video_frames:
            # Encode frame without gradients
            with torch.no_grad():
                encoded_frame = self.clip_model.encode_image(frame)
            # Append encoded frame to list
            encoded_frames.append(encoded_frame)
        # Average encoded frames
        averaged_embedding = torch.stack(encoded_frames).mean(dim=0)
        # Detach output tensor from computational graph
        detached_embedding = averaged_embedding.detach()
        # Pass the detached embedding through the MLP
        output = self.mlp(detached_embedding)
        return output
# Example usage:
model = VideoToEmbeddingModel()

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end=True
    learning_rate=5e-5,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    save_on_each_node=True,
    fp16=True,
    deepspeed="ds_config.json",  # Your DeepSpeed config file
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=your_train_dataset,  # Your training dataset
    eval_dataset=your_eval_dataset,    # Your evaluation dataset
    # predictions/label_ids arrive as NumPy arrays, so compute accuracy with NumPy ops
    compute_metrics=lambda pred: {"accuracy": (pred.predictions.argmax(-1) == pred.label_ids).mean()},
)

# Train the model
trainer.train()
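To make the cross-rank mismatch explicit (which is what happens naturally with real videos of different lengths), a variation like the following inside forward() should trigger the hang. This is an untested sketch and assumes torch.distributed has been initialized by the launcher:

# Untested sketch: give each rank a different number of encoder passes,
# mimicking videos of different lengths across ranks
import torch.distributed as dist

num_frames = 4 + dist.get_rank()            # rank-dependent loop count
for frame in video_frames[:num_frames]:     # replaces the loop in forward() above
    with torch.no_grad():
        encoded_frames.append(self.clip_model.encode_image(frame))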
Use this ZeRO-3 JSON config:
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
Expected behavior I would like a way to still encode long videos while using ZeRO-3, since ZeRO-3 becomes essential once I train larger LLMs.
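One workaround I have considered (untested, and it wastes compute) is padding every rank to the same number of encoder forward passes, so that ZeRO-3's per-forward parameter gathers stay matched across ranks. A rough sketch of what I mean, to replace the loop in forward(); I don't know whether this is the intended approach:

# Untested idea: pad all ranks to the global max number of encoder passes
# so ZeRO-3's parameter-gather collectives line up across ranks.
import torch.distributed as dist

num_frames = torch.tensor([len(video_frames)], device="cuda")
dist.all_reduce(num_frames, op=dist.ReduceOp.MAX)        # global max frame count
for i in range(int(num_frames.item())):
    frame = video_frames[min(i, len(video_frames) - 1)]  # repeat last frame as padding
    with torch.no_grad():
        enc = self.clip_model.encode_image(frame)
    if i < len(video_frames):                            # keep only real frames in the average
        encoded_frames.append(enc)

If there is a cleaner, supported way to exclude the encoder from ZeRO-3's handling, I would prefer that.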
ds_report output
Not attached. The observed failure mode is an NCCL hang.
System info:
- Ubuntu 18.04
- 8 nodes, 8 A100s each
Launcher context Using the deepspeed launcher with a hostfile.
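The invocation looks roughly like this (script name and hostfile path are placeholders):

deepspeed --hostfile=hostfile --num_nodes=8 --num_gpus=8 train.py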