activation_checkpointing error when using --fsdp

getao opened this issue 1 year ago (status: Open, 6 comments)

System Info

transformers == 4.36.2, pytorch == 2.1.0

Who can help?

When I use DeepSpeed with activation checkpointing enabled, training works. However, when I switch to torchrun with the native PyTorch FSDP integration in the Hugging Face Trainer (https://huggingface.co/docs/transformers/main/main_classes/trainer#transformers.TrainingArguments.fsdp), training fails with the following error:

  File "/workspace/training_script.py", line 77, in train_model
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2744, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1905, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1065, in unpack_hook
    args = ctx.get_args(ctx.saved_tensors)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 850, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.

tensor at position 13:
saved metadata: {'shape': torch.Size([1, 3112, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([1, 9336, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}

tensor at position 14:
saved metadata: {'shape': torch.Size([1, 3112, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([1, 9336, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}

The model is Llama-2; I did not modify its forward function, and I train it with Trainer. I wonder whether something is wrong with the activation_checkpointing feature (enabled in fsdp_config.json) when it is used together with --fsdp.

Thank you

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [x] My own task or dataset (give details below)

Reproduction

I train Llama with Trainer using the following arguments:

--fsdp shard_grad_op --fsdp_config fsdp_config.json (where activation_checkpointing is set to true)
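The reporter's fsdp_config.json is not included in the issue. As an assumption for illustration only, a config that enables the option under discussion could be written as below; the key names activation_checkpointing and transformer_layer_cls_to_wrap follow the fsdp_config section of the TrainingArguments documentation, while wrapping LlamaDecoderLayer is a guess rather than something taken from the report.

```python
import json

# Hypothetical fsdp_config.json contents; the reporter's actual file is not shown.
fsdp_config = {
    "activation_checkpointing": True,  # the option under discussion in this issue
    "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],  # assumed: wrap each Llama decoder layer
}

with open("fsdp_config.json", "w") as f:
    json.dump(fsdp_config, f, indent=2)
```

The resulting file would then be passed as --fsdp shard_grad_op --fsdp_config fsdp_config.json.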

Expected behavior

Training runs properly, with the memory savings that activation checkpointing should provide.

getao, Jan 14 '24

Hi @getao, thanks for raising an issue!

Could you provide a minimal reproducer for this error, specifically the full CLI command used to launch the training job?

cc @muellerzr @pacman100

amyeroberts, Jan 15 '24

Sure.

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# ModelArguments, DataArguments, TrainingArguments, tokenize_function, data_collator
# and train_model come from the rest of the training script (not shown in this issue).


def main():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    data_prefix = data_args.data_path
    train_file = f"{data_prefix}.train.json"  # text data for language modeling (next-word prediction)
    eval_file = f"{data_prefix}.eval.json"

    dataset = load_dataset("json", data_files={"train": train_file, "eval": eval_file})
    train_dataset = dataset["train"]
    eval_dataset = dataset["eval"]
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenize_kwargs = {
        "tokenizer": tokenizer,
        "max_seq_length": data_args.max_seq_length,
        "add_special_tokens": data_args.add_special_tokens,
    }
    train_dataset = train_dataset.map(tokenize_function, batched=True, fn_kwargs=tokenize_kwargs, load_from_cache_file=True)
    eval_dataset = eval_dataset.map(tokenize_function, batched=True, fn_kwargs={**tokenize_kwargs, "dev": True}, load_from_cache_file=True)

    # Retry until the Llama-2 checkpoint has been downloaded and loaded successfully.
    model_download_flag = False
    while not model_download_flag:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
                use_flash_attention_2=True,
                resume_download=True,
            )  # Llama-2
            model_download_flag = True
        except Exception as e:
            print(e)

    train_model(model, train_dataset, eval_dataset, training_args, data_collator)


if __name__ == "__main__":
    main()

The CLI command is below; there is nothing special except the two FSDP-related flags (--fsdp and --fsdp_config):

torchrun --nproc-per-node=6 train_script.py --adam_beta2 0.95 --adam_epsilon 1e-6 --num_train_epochs $epoch --per_device_train_batch_size $batch --per_device_eval_batch_size $batch --gradient_accumulation_steps 1 --fsdp shard_grad_op --fsdp_config my_fsdp_config_path \
         --learning_rate $lr --warmup_steps $warmup --max_grad_norm 1.0 --seed $seed --data_seed $seed --logging_steps 10 --save_strategy 'no' --evaluation_strategy 'steps' --eval_steps $eval_steps \
         --save_steps $save_steps --bf16  --output_dir $OUT_DIR --logging_dir $OUT_DIR --data_path $INPUT_DATA | tee $OUT_DIR/train.log

getao, Jan 15 '24

I am having the same issue, and I even observe that the tensor shapes increase by an exact multiple of the number of GPUs used (like at index position 1 in OP's stack trace).

Additionally, if I freeze the pretrained part of my model as opposed to leaving everything trainable, the error becomes:

.venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 147, in _p_assert
    raise AssertionError(s)
AssertionError: Prefetching is only supported in (<HandleTrainingState.BACKWARD_PRE: 3>, <HandleTrainingState.BACKWARD_POST: 4>, <HandleTrainingState.FORWARD: 2>) but currently in HandleTrainingState.IDLE

vtien, Feb 27 '24

I'm encountering the same issue. Are there any updates or workarounds for this? Is there any way I can help?

wulu473, Mar 21 '24

Gentle ping @muellerzr @pacman100

amyeroberts, Mar 22 '24

I was able to resolve this error by calling model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True}), as found in https://github.com/Lightning-AI/pytorch-lightning/issues/19267

I'm still unsure of the root cause, but this fixed it for me. Also, if anyone encounters the prefetching error pasted above, experiment with the layers in your auto_wrap_policy: including unnecessarily small submodules of layers reproduced this error for me, and being more careful about what was included resolved it.
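The advice above concerns the auto_wrap_policy handed to FSDP (with the Trainer, the closest setting is transformer_layer_cls_to_wrap in fsdp_config). Below is a minimal sketch, assuming PyTorch 2.x, of a policy that wraps only whole transformer blocks instead of small submodules inside them; DecoderBlock is a hypothetical stand-in for a layer class such as LlamaDecoderLayer.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class DecoderBlock(nn.Module):
    """Hypothetical stand-in for a transformer decoder layer."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.attn_proj = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mlp(self.attn_proj(x))


# Wrap whole DecoderBlock instances only; their inner nn.Linear layers
# are not wrapped as separate FSDP units.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={DecoderBlock},
)

block = DecoderBlock()
# Arguments are (module, recurse, nonwrapped_numel); recurse=False asks
# "should this module be wrapped?"
print(auto_wrap_policy(block, False, 0))      # the block itself
print(auto_wrap_policy(block.mlp, False, 0))  # a small submodule inside it
```

In a real run the policy would be passed as FSDP(model, auto_wrap_policy=auto_wrap_policy); querying it directly, as above, is just a cheap way to check which modules it would wrap.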

vtien, Mar 22 '24

@vtien @getao Do you get the CheckpointError with use_flash_attention_2=False?

irenedea, Apr 18 '24

I didn't try use_flash_attention_2=False

getao, Apr 26 '24

@vtien Genius! Using the reentrant checkpointing option also works for me. As described in https://pytorch.org/docs/stable/checkpoint.html, the reentrant and non-reentrant checkpointing implementations differ in several ways. I will check whether there are performance differences.
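As far as I know, gradient_checkpointing_enable(gradient_checkpointing_kwargs=...) (available since transformers 4.35) forwards these kwargs to torch.utils.checkpoint.checkpoint, so the workaround switches to the reentrant implementation. One caveat, shown with a toy (the hypothetical shape_changing_block, not the Llama code): the reentrant implementation does not compare saved and recomputed tensor metadata, so a mismatch that raises CheckpointError in non-reentrant mode can pass silently, with gradients computed from the recomputed graph. Whether that happens in the FSDP setup discussed here is not established in this thread.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # counts how many times the checkpointed function has run


def shape_changing_block(x):
    # Hypothetical toy block whose recomputation differs from the original forward.
    calls["n"] += 1
    y = x.repeat(calls["n"], 1)
    return (y * y).sum()


x = torch.randn(2, 3, requires_grad=True)
out = checkpoint(shape_changing_block, x, use_reentrant=True)
out.backward()  # completes: reentrant mode performs no saved-vs-recomputed metadata check

# The original forward (repeat factor 1) has gradient 2 * x, but the recomputed
# graph (repeat factor 2) is what gets differentiated, giving 4 * x.
print(x.grad)
```

In other words, the absence of an error with use_reentrant=True does not by itself show that the recomputed activations match the original ones.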

shun-zheng, May 16 '24

@irenedea I tried 'attn_implementation=eager' and still get a similar error. Interestingly, in this case the reentrant checkpointing strategy does not help. The bug is really weird.

shun-zheng, May 16 '24