More robust tests required for gradient checkpointing
As flagged in this comment from #31945, the current test suite for gradient checkpointing may fail to spot subtle bugs. For example, the bug reported in #31028 was not detected by any test, because the success criterion is simply that gradients are not `None` for each parameter that has `requires_grad=True`.
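For concreteness, this is roughly what that criterion amounts to (a minimal sketch, not the actual test code from the suite; `model` is a placeholder that has already run a backward pass):

```python
# Weak criterion: only checks that a gradient exists, not that it is correct.
for name, param in model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"{name} has no gradient"
```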
Proposal: every function whose behavior depends on whether gradient checkpointing is enabled, such as a model's `forward()` method, should be tested twice: once with GC enabled and once with GC disabled, verifying that the outputs are identical for the same inputs and configuration parameters.
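A minimal sketch of what such a test could look like, assuming a model that exposes the standard `gradient_checkpointing_enable()`/`gradient_checkpointing_disable()` API and returns an output with a `logits` field; `model` and `inputs` are placeholders rather than fixtures from the actual test suite, and dropout is assumed to be disabled in the config (e.g. all dropout probabilities set to 0.0) so the forward pass is deterministic:

```python
import torch

def check_gc_consistency(model, inputs):
    # Gradient checkpointing in transformers is typically only active in
    # training mode, so run in train() with dropout disabled via the config.
    model.train()

    model.gradient_checkpointing_disable()
    out_plain = model(**inputs).logits

    model.gradient_checkpointing_enable()
    out_gc = model(**inputs).logits

    # Checkpointing changes how activations are (re)computed, not what is
    # computed, so the outputs should match exactly, not just approximately.
    torch.testing.assert_close(out_gc, out_plain, rtol=0.0, atol=0.0)
```

The same pattern could be extended to compare gradients after a backward pass, which would have caught a bug like #31028 even when every `param.grad` is non-`None`.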