
load_checkpoint_and_dispatch OOMs

sriniiyer opened this issue on May 18, 2023 · 4 comments

System Info

- `Accelerate` version: 0.18.0
- Platform: Linux-5.15.0-1015-aws-x86_64-with-glibc2.31
- Python version: 3.9.16
- Numpy version: 1.24.1
- PyTorch version (GPU?): 2.0.0.dev20230202+cu116 (False)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: DEEPSPEED
        - use_cpu: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - deepspeed_config: {'deepspeed_config_file': 'deepspeed_z3.json', 'zero3_init_flag': True}
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
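
The deepspeed_z3.json referenced in the config above is not included in the report. For context, a minimal ZeRO Stage 3 config of the kind that setting points at might look like the following sketch; every field value here is an assumption, not the reporter's actual file.

import json

# Hypothetical reconstruction of a minimal ZeRO Stage 3 config file;
# the reporter's actual deepspeed_z3.json is not shown in the issue.
zero3_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

with open("deepspeed_z3.json", "w") as f:
    json.dump(zero3_config, f, indent=2)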

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • [X] My own task or dataset (give details below)

Reproduction

import sys
import torch
from transformers import LlamaForCausalLM
from accelerate import Accelerator, load_checkpoint_and_dispatch
import deepspeed

def main():
    accelerator = Accelerator()

    # Load the base weights into CPU memory, then dispatch the fine-tuned
    # checkpoint across the available devices.
    sft_model = LlamaForCausalLM.from_pretrained('llama-7b/')
    sft_model = load_checkpoint_and_dispatch(sft_model, 'models/best/pytorch_model.bin', device_map="auto")

    opt = torch.optim.Adam(sft_model.parameters(), lr=1e-5)
    sft_model, opt = accelerator.prepare(sft_model, opt)

    sft_model.train()
    # training loop elided (the original script called accelerator.train(),
    # which is not an Accelerator method)

if __name__ == "__main__":
    sys.exit(main())
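
For reference, the pattern the Accelerate docs pair with load_checkpoint_and_dispatch builds the model skeleton under init_empty_weights first, so the weights are not materialized twice (once by from_pretrained and once by the checkpoint load). A minimal sketch of that variant, reusing the paths above; the no_split_module_classes choice is an illustrative assumption, not something from the report:

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, LlamaForCausalLM

# Instantiate the architecture on the meta device (no real allocations),
# then stream checkpoint weights directly onto the mapped devices.
config = AutoConfig.from_pretrained('llama-7b/')
with init_empty_weights():
    sft_model = LlamaForCausalLM(config)

sft_model = load_checkpoint_and_dispatch(
    sft_model,
    'models/best/pytorch_model.bin',
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],  # keep each decoder block on one device
)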

Expected behavior

This OOMs on a node with 8× 80GB A100 GPUs, and it should not. I also tried device_map="balanced" and device_map="balanced_low_0", and I get the same OOM.
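
One knob worth checking when device_map placement OOMs is load_checkpoint_and_dispatch's max_memory argument, which caps how much of each device the dispatcher may fill so headroom remains for activations and optimizer state. A hedged sketch; the 70GiB per-GPU cap and the CPU budget are illustrative figures, not values from the report:

from accelerate import load_checkpoint_and_dispatch

# Leave headroom on each of the eight A100s; spill any remainder to CPU RAM.
# Both budgets are illustrative assumptions.
max_memory = {i: "70GiB" for i in range(8)}
max_memory["cpu"] = "200GiB"

sft_model = load_checkpoint_and_dispatch(
    sft_model,
    'models/best/pytorch_model.bin',
    device_map="auto",
    max_memory=max_memory,
)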
