[BUG] ZeRO-3 hangs during inference of part of the model; need to detach part of the computational graph, but .detach()/torch.no_grad() do not work

Open orrzohar opened this issue 1 year ago • 3 comments

Describe the bug

I am training a video LLM in which I encode long videos with a varying number of forward passes to avoid OOM issues. I would like to use ZeRO-3, but running part of the model a different number of times on each rank makes the computational graph differ across nodes/GPUs, and ZeRO-3 hangs.

I don't need gradients for the video encoder and would like to remove it from the computational graph entirely. I have tried (1) freezing the encoder, (2) running the encoder under torch.no_grad(), and (3) calling .detach() on its output tensors, all to no avail.
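A possible reason none of the three attempts helps, as I understand ZeRO-3's design: ZeRO-3 gathers a module's partitioned parameters from hooks that fire on every forward call of that module, and these gathers are collectives that every rank must join. Freezing, torch.no_grad() and .detach() only change autograd bookkeeping; they do not change how many times the encoder runs. The toy below does not use DeepSpeed. It uses a plain PyTorch forward pre-hook as a hypothetical stand-in for that parameter fetch and counts how often it fires.

```python
import torch
import torch.nn as nn

encoder = nn.Linear(4, 4)
encoder.requires_grad_(False)  # attempt (1): freeze the encoder

calls = [0]


def fetch_params_hook(module, args):
    # Hypothetical stand-in for the per-module hook with which ZeRO-3 gathers
    # partitioned parameters before each forward call; here it only counts.
    calls[0] += 1


encoder.register_forward_pre_hook(fetch_params_hook)

frames = torch.randn(5, 4)
with torch.no_grad():  # attempt (2): run the encoder under no_grad
    outs = [encoder(frame) for frame in frames]
embedding = torch.stack(outs).mean(dim=0).detach()  # attempt (3): detach

print(calls[0])  # 5: the hook fired once per frame despite (1)-(3)
```

If two ranks encode videos with different numbers of frames, they fire this hook a different number of times, so under ZeRO-3 one rank ends up waiting on a collective the other never issues.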

How can I effectively "remove" part of the model from checkpointing, if that is possible at all? I can't pre-encode entire videos, as that is too memory-heavy for my setup.
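One way to "remove" the frozen encoder from ZeRO-3's view is to keep it out of the module tree that DeepSpeed wraps. As far as I know, ZeRO-3 partitions the parameters reachable through the wrapped model, so an encoder that is not registered as a submodule would stay whole on every rank and run locally without collectives. This is a sketch under assumptions. The class name is hypothetical. It assumes the encoder was not built inside a deepspeed.zero.Init context, which would already have partitioned it. Because the encoder is invisible to the wrapper, you must move it to the rank's GPU yourself, it is not saved in checkpoints, and every GPU holds a full copy of it.

```python
import torch
import torch.nn as nn


class VideoToEmbeddingModelSideEncoder(nn.Module):
    """Hypothetical variant of VideoToEmbeddingModel whose frozen encoder is
    kept outside the registered module tree."""

    def __init__(self, encoder: nn.Module, mlp_input_dim=512, mlp_hidden_dim=128, mlp_output_dim=8):
        super().__init__()
        encoder.requires_grad_(False)
        encoder.eval()
        # object.__setattr__ bypasses nn.Module.__setattr__, so the encoder is
        # not added to self._modules: it is absent from parameters(),
        # state_dict(), and .to()/.half() calls on this model.
        object.__setattr__(self, "encoder", encoder)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_output_dim),
        )

    def forward(self, video_frames):
        # video_frames: [batch, time, ...]; encode one time step at a time
        with torch.no_grad():
            per_frame = [self.encoder(frame) for frame in video_frames.unbind(dim=1)]
            embedding = torch.stack(per_frame).mean(dim=0)  # [batch, dim]
        return self.mlp(embedding)
```

With OpenAI CLIP, `encode_image` as far as I know just calls `clip_model.visual` on dtype-cast input, so passing `clip_model.visual` (already on the rank's device) as `encoder` is one option. You may need to cast the embedding to the MLP's dtype when DeepSpeed runs the MLP in fp16.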

To Reproduce

Take any model and apply some part of it multiple times, with the number of applications varying across ranks. For example:

import torch
import torch.nn as nn
from transformers import Trainer, TrainingArguments
import clip

class VideoToEmbeddingModel(nn.Module):
    def __init__(self, clip_model_name="ViT-B/32", mlp_input_dim=512, mlp_hidden_dim=128, mlp_output_dim=8):
        super().__init__()
        self.clip_model, _ = clip.load(clip_model_name)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_output_dim)
        )

    def forward(self, video_frames, labels=None):
        # video_frames: [batch, time, 3, H, W]
        encoded_frames = []

        # Loop over the time dimension (dim=1), not the batch dimension
        for frame in video_frames.unbind(dim=1):
            # Encode frame without gradients
            with torch.no_grad():
                encoded_frame = self.clip_model.encode_image(frame)
            # Append encoded frame to list
            encoded_frames.append(encoded_frame)

        # Average encoded frames over time: [time, batch, dim] -> [batch, dim]
        averaged_embedding = torch.stack(encoded_frames).mean(dim=0)

        # Detach output tensor from computational graph
        detached_embedding = averaged_embedding.detach()

        # Pass the detached embedding through the MLP
        logits = self.mlp(detached_embedding)

        if labels is None:
            return {"logits": logits}
        # Trainer reads the training loss from the "loss" key
        return {"loss": nn.functional.cross_entropy(logits, labels), "logits": logits}

# Example usage:
model = VideoToEmbeddingModel()

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    save_total_limit=2,
    save_strategy="epoch",  # must match evaluation_strategy because load_best_model_at_end=True
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    save_on_each_node=True,
    fp16=True,
    deepspeed="ds_config.json",  # Your DeepSpeed config file
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=your_train_dataset,  # Your training dataset
    eval_dataset=your_eval_dataset,  # Your evaluation dataset
    compute_metrics=lambda pred: {"accuracy": float((pred.predictions.argmax(-1) == pred.label_ids).mean())},  # NumPy arrays, not tensors
)

# Train the model
trainer.train()

Use this ZeRO-3 config as ds_config.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}

Expected behavior

I would like to be able to encode long videos while using ZeRO-3, since ZeRO-3 becomes essential when I train larger LLMs.
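For encoding long videos with the encoder still under ZeRO-3, one approach is to make every rank issue the same number of encoder calls and pad shorter videos with throwaway calls. This is the same idea behind the `synced_gpus` option of Hugging Face `generate()`. The sketch below is an assumption-laden illustration, not a DeepSpeed API; both helper names are hypothetical. It assumes torch.distributed is already initialized (as under the deepspeed launcher), that with the NCCL backend `frames` already sits on the rank's GPU, and that each video has at least one frame.

```python
import math

import torch
import torch.distributed as dist


def synced_num_chunks(num_frames: int, chunk_size: int, device=None) -> int:
    """Number of encoder calls every rank should make for this video."""
    local = math.ceil(num_frames / chunk_size)
    if dist.is_available() and dist.is_initialized():
        # Agree on the largest chunk count across all ranks; with the NCCL
        # backend, `device` must be this rank's CUDA device.
        count = torch.tensor([local], dtype=torch.long, device=device)
        dist.all_reduce(count, op=dist.ReduceOp.MAX)
        return int(count.item())
    return local  # single process: nothing to synchronize


@torch.no_grad()
def encode_video_synced(encode_fn, frames: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Mean-pool per-frame embeddings of `frames` ([T, ...], T >= 1), encoding
    `chunk_size` frames at a time and padding with throwaway calls so every
    rank runs `encode_fn` the same number of times."""
    n_calls = synced_num_chunks(frames.shape[0], chunk_size, device=frames.device)
    embeddings = []
    for i in range(n_calls):
        chunk = frames[i * chunk_size:(i + 1) * chunk_size]
        if chunk.shape[0] == 0:
            # This rank has no frames left, but other ranks still need it to
            # join the parameter gathers; run a call and discard its output.
            encode_fn(frames[:1])
            continue
        embeddings.append(encode_fn(chunk))
    return torch.cat(embeddings, dim=0).mean(dim=0)
```

In the repro, `encode_fn` would be `self.clip_model.encode_image`, and `chunk_size` is a hypothetical memory knob. Every rank must also process the same number of videos per step, because agreeing on the chunk count is itself a collective.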

ds_report output

Not included; the symptom is an NCCL hang.

System info:

  • Ubuntu 18.04
  • 8 nodes, 8 A100s each

Launcher context

Using the deepspeed launcher with a hostfile.

orrzohar, Aug 25 '24 17:08