load_checkpoint_and_dispatch OOMs
System Info
- `Accelerate` version: 0.18.0
- Platform: Linux-5.15.0-1015-aws-x86_64-with-glibc2.31
- Python version: 3.9.16
- Numpy version: 1.24.1
- PyTorch version (GPU?): 2.0.0.dev20230202+cu116 (False)
- `Accelerate` default config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - use_cpu: False
  - num_processes: 8
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - deepspeed_config: {'deepspeed_config_file': 'deepspeed_z3.json', 'zero3_init_flag': True}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
import sys

import torch
from transformers import LlamaForCausalLM
from accelerate import Accelerator
import deepspeed
from accelerate import load_checkpoint_and_dispatch


def main():
    accelerator = Accelerator()

    # Load the base model (full weights in CPU RAM), then load the fine-tuned
    # checkpoint and dispatch the model across devices according to device_map.
    sft_model = LlamaForCausalLM.from_pretrained('llama-7b/')
    sft_model = load_checkpoint_and_dispatch(sft_model, 'models/best/pytorch_model.bin', device_map="auto")

    opt = torch.optim.Adam(sft_model.parameters(), lr=1e-5)
    sft_model, opt = accelerator.prepare(sft_model, opt)

    sft_model.train()
    accelerator.train()


if __name__ == "__main__":
    sys.exit(main())
Expected behavior
This OOMs on a node with 8× 80 GB A100 GPUs, and it should not. I also tried `device_map="balanced"` and `device_map="balanced_low_0"` and hit the same OOM.
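
For reference, the big-model loading pattern from the Accelerate docs that I would have expected to keep memory bounded looks roughly like the sketch below. It is only a sketch: the local paths are carried over from my repro, and the `no_split_module_classes` value for Llama is my assumption, not something taken from the docs.

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, LlamaForCausalLM


def load_dispatched_model():
    # Build the model skeleton on the meta device so no real weights are allocated yet.
    config = AutoConfig.from_pretrained('llama-7b/')  # same local path as in my repro
    with init_empty_weights():
        model = LlamaForCausalLM(config)

    # Load the fine-tuned checkpoint onto the devices picked by device_map,
    # keeping each decoder layer whole on one device (module class name is my assumption).
    model = load_checkpoint_and_dispatch(
        model,
        'models/best/pytorch_model.bin',  # same checkpoint path as in my repro
        device_map="auto",
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    return model

With `init_empty_weights` the skeleton carries no real tensors, so the checkpoint weights only have to fit on the devices chosen by `device_map` rather than on top of a fully materialized CPU copy, which is why I expected an 8× 80 GB node to be more than enough for a 7B model.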