Gradient checkpointing throws use_reentrant warning on PyTorch 2.1
System Info
- transformers version: 4.36.2
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Python version: 3.10.13
- Huggingface_hub version: 0.19.4
- Safetensors version: 0.4.0
- Accelerate version: 0.25.0
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 8
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero3_save_16bit_model': False, 'zero_stage': 3}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
- Jax version: 0.4.21
- JaxLib version: 0.4.21
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
Who can help?
@ArthurZucker @younesbelkada
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Training any text model with gradient checkpointing enabled on PyTorch 2.1 and higher produces this warning:
```
/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: Warning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
```
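For reference, here is a minimal sketch of the kind of run that surfaces the warning. The model, dataset, and hyperparameters are placeholders rather than the original training script; any text model with gradient checkpointing enabled behaves the same way on PyTorch 2.1+.

```python
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

model_name = "gpt2"  # placeholder; any text model shows the same behaviour
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_name)

# Tiny placeholder dataset, filtered to drop empty lines.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:64]")
raw = raw.filter(lambda ex: len(ex["text"].strip()) > 0)
tokenized = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=raw.column_names,
)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    max_steps=1,
    gradient_checkpointing=True,  # checkpointing on, use_reentrant left to the default
    report_to="none",
)
Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=collator,
).train()
# Commenters in this thread report the torch.utils.checkpoint warning being
# emitted during the checkpointed forward pass in setups like this one.
```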
This can be resolved by manually monkey-patching the model code to pass use_reentrant=True, e.g.:
```python
hidden_states, self_attns, decoder_cache = torch.utils.checkpoint.checkpoint(
    create_custom_forward(decoder_layer),
    hidden_states,
    attention_mask,
    position_ids,
    None,
    is_padded_inputs,
    use_reentrant=True,
)
```
This is caused by an upstream change in PyTorch:
https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d
Expected behavior
No warning should be emitted.
Thanks for raising this! Given that we had #27020, this should be fairly easy to fix! cc @younesbelkada
@ArthurZucker is this still outstanding?
Will merge the PR today
Which version fixed this? I'm using 3.47.2 and still get this error.
4.39.3 still gets this warning.
In my case, calling model.gradient_checkpointing_enable() fixed it. Maybe you can try that, @lucasjinreal.
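If you want to be explicit rather than rely on the default, here is a minimal sketch that forwards the flag yourself instead of monkey-patching. It assumes your transformers version accepts gradient_checkpointing_kwargs on gradient_checkpointing_enable (the API referenced around #27020); worth double-checking the signature for your installed version.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

# Pass use_reentrant explicitly so torch.utils.checkpoint stops warning.
# gradient_checkpointing_kwargs is assumed to be supported by your version.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```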
I'm using transformers==4.43.3, and still getting errors when trying to use the Trainer API with gradient_checkpointing=True.
Me too. Try using model.gradient_checkpointing_enable() and do not specify gradient_checkpointing=True in the Hugging Face Trainer API. That solved my problem. A sketch of both options follows below.
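For the Trainer route, a sketch of the two options under discussion: enable checkpointing on the model before constructing the Trainer, or pass gradient_checkpointing_kwargs through TrainingArguments. The latter field is an assumption to verify against your installed version.

```python
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

# Option A (what the comment above suggests): enable on the model yourself
# and leave gradient_checkpointing out of TrainingArguments.
model.gradient_checkpointing_enable()
args_a = TrainingArguments(output_dir="out", report_to="none")

# Option B (assumed available on recent versions; verify the field exists):
# keep gradient_checkpointing=True but pass use_reentrant explicitly.
args_b = TrainingArguments(
    output_dir="out",
    report_to="none",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = Trainer(model=model, args=args_a)  # or args_b; train_dataset etc. omitted for brevity
```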