Gradient checkpointing throws use_reentrant warning on PyTorch 2.1

Open rosario-purple opened this issue 1 year ago • 2 comments

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.25.0
  • Accelerate config:
      - compute_environment: LOCAL_MACHINE
      - distributed_type: DEEPSPEED
      - mixed_precision: bf16
      - use_cpu: False
      - debug: False
      - num_processes: 8
      - machine_rank: 0
      - num_machines: 1
      - rdzv_backend: static
      - same_network: True
      - main_training_function: main
      - deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero3_save_16bit_model': False, 'zero_stage': 3}
      - downcast_bf16: no
      - tpu_use_cluster: False
      - tpu_use_sudo: False
      - tpu_env: []
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
  • Jax version: 0.4.21
  • JaxLib version: 0.4.21
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@ArthurZucker @younesbelkada

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Training any text model with gradient checkpointing enabled on PyTorch 2.1 and higher produces this warning:

/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: Warning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

This can be resolved by manually monkey-patching the model code to pass use_reentrant=True explicitly, e.g. like so:

                hidden_states, self_attns, decoder_cache = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    is_padded_inputs,
                    use_reentrant=True,
                )
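
For comparison, here is a minimal standalone sketch of the same idea outside the model code: passing use_reentrant explicitly to torch.utils.checkpoint.checkpoint is what silences the warning. The toy nn.Linear layer and the tensor shapes are illustrative assumptions, not part of the issue. use_reentrant=False is used here because the warning text recommends it, whereas the patch above keeps use_reentrant=True to preserve current behavior.

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint

# Toy layer and input, purely illustrative (not taken from the issue).
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Passing use_reentrant explicitly avoids the "please pass in
    # use_reentrant=True or use_reentrant=False explicitly" warning.
    out = checkpoint(layer, x, use_reentrant=False)
    out.sum().backward()

# Collect any recorded warnings that still mention use_reentrant.
reentrant_warnings = [w for w in caught if "use_reentrant" in str(w.message)]
print(f"use_reentrant warnings: {len(reentrant_warnings)}")
```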

This is caused by an upstream change in PyTorch:

https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d

Expected behavior

No warning should be emitted.

rosario-purple, Jan 16 '24 16:01

Thanks for raising! Given that we had #27020, this should be fairly easy to fix! cc @younesbelkada

ArthurZucker, Jan 16 '24 16:01

@ArthurZucker is this still outstanding?

rosario-purple, Feb 16 '24 21:02

Will merge the PR today.

ArthurZucker, Feb 19 '24 03:02

Which version includes this fix? I am using 3.47.2 and still get this error.

lucasjinreal, Apr 09 '24 02:04

Still getting this warning on 4.39.3.

huangganggui, Apr 11 '24 06:04

> Still getting this warning on 4.39.3.

In my case, calling model.gradient_checkpointing_enable() fixed it. Maybe you can try that, @lucasjinreal.

huangganggui, Apr 11 '24 07:04

I'm using transformers==4.43.3, and still getting errors when trying to use the Trainer API with gradient_checkpointing=True.

ankush13r, Aug 04 '24 15:08

> I'm using transformers==4.43.3, and still getting errors when trying to use the Trainer API with gradient_checkpointing=True.

Me too. Try calling model.gradient_checkpointing_enable() instead of specifying gradient_checkpointing=True in the Hugging Face Trainer API. That solved my problem.

BigDataMLexplorer, Aug 11 '24 15:08