transformers
Support gradient checkpointing for ESM models
Could you please add support for gradient_checkpointing_enable() to the ESM models?
They are currently the best pre-trained protein language models available to researchers.
Many thanks.
cc @Rocketknight1
Any updates?
It's on the to-do list, but I'm afraid there are competing priorities at the moment!
Let's open it up for anyone in the community who might want to tackle it :)
Hi @amyeroberts @Rocketknight1, I would like to work on this.
@sanjeevk-os Great! Once you have the code ready, open a PR and ping both @Rocketknight1 and me. Looking forward to reviewing!
Hi @sanjeevk-os, I took a look at the ESM code, and it looks like some of the support for gradient checkpointing is already there. In that case, you just need a one-line change to set supports_gradient_checkpointing = True.
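For context on why a single class attribute can be enough: below is a minimal, stdlib-only sketch of how a class-level supports_gradient_checkpointing flag can gate a gradient_checkpointing_enable() method. ToyPreTrainedModel, ToyEsmModelBefore and ToyEsmModelAfter are hypothetical stand-ins, not the transformers API; the real PreTrainedModel performs a similar check but does more work when enabling checkpointing.

```python
class ToyPreTrainedModel:
    """Hypothetical stand-in for a pretrained-model base class."""

    # Subclasses opt in by overriding this class attribute.
    supports_gradient_checkpointing = False

    def __init__(self):
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        # Refuse to enable checkpointing unless the subclass has opted in.
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{type(self).__name__} does not support gradient checkpointing."
            )
        self.gradient_checkpointing = True


class ToyEsmModelBefore(ToyPreTrainedModel):
    pass  # flag left at its default of False, so enabling raises


class ToyEsmModelAfter(ToyPreTrainedModel):
    # The hypothetical "one-line change": declare support for checkpointing.
    supports_gradient_checkpointing = True


model = ToyEsmModelAfter()
model.gradient_checkpointing_enable()
print(model.gradient_checkpointing)  # True
```

The design keeps the opt-in declarative: the base class owns the enabling logic, and each model only states whether its forward pass is ready for it.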
Hi @Rocketknight1, thank you for taking a look. I also noticed that the ESM model already passes create_custom_forward to the torch checkpoint function. I will do some more checks and raise a PR soon.
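The create_custom_forward helper is a closure pattern: the checkpoint function receives a callable that takes only positional inputs, while a non-tensor argument such as the output_attentions flag is captured from the enclosing scope. Here is a stdlib-only sketch of that pattern. toy_checkpoint is a hypothetical stand-in for torch.utils.checkpoint.checkpoint that forwards only positional arguments and hands back a callable to re-run the forward pass (mimicking recomputation during backward); toy_layer and toy_encoder_forward are likewise hypothetical, and plain lists stand in for tensors.

```python
def toy_checkpoint(function, *args):
    """Hypothetical stand-in for torch.utils.checkpoint.checkpoint.

    Forwards only positional arguments. Returns the output plus a callable
    that re-runs the forward pass, mimicking how checkpointing recomputes
    activations during backward instead of keeping them in memory.
    """
    output = function(*args)

    def recompute():
        return function(*args)

    return output, recompute


def toy_layer(hidden_states, attention_mask, output_attentions=False):
    """Hypothetical encoder layer: masks inputs, optionally returns attentions."""
    out = [h * m for h, m in zip(hidden_states, attention_mask)]
    return (out, "attn") if output_attentions else (out,)


def toy_encoder_forward(layers, hidden_states, attention_mask,
                        output_attentions=False, gradient_checkpointing=True):
    for layer in layers:
        if gradient_checkpointing:
            # Bind output_attentions in a closure so the checkpoint function
            # only ever sees positional, tensor-like inputs.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs, output_attentions)
                return custom_forward

            layer_outputs, _ = toy_checkpoint(
                create_custom_forward(layer), hidden_states, attention_mask
            )
        else:
            layer_outputs = layer(hidden_states, attention_mask, output_attentions)
        hidden_states = layer_outputs[0]
    return hidden_states


print(toy_encoder_forward([toy_layer, toy_layer], [1.0, 2.0, 3.0], [1, 1, 0]))
# [1.0, 2.0, 0.0]
```

Both branches produce the same hidden states; the checkpointed path only changes how the layer is invoked, which is why existing create_custom_forward plumbing makes enabling the feature cheap.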
Hi @sanjeevk-os - we're getting even more requests for this, so we'd like to try to add it soon! If you're having trouble, just let us know. We can take over the PR internally to try to get it through, and we appreciate your effort regardless.
This issue has now been resolved - thank you to @sanjeevk-os for the very clean PR!