Load state dict with assign to avoid OOMs
Description
I keep getting OOMs when trying to load Llama 3 8B. I narrowed the offending code down to load_and_process_state_dict, then realized Neel had already suggested a fix in #480 that isn't implemented yet.
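For context, here is a minimal sketch in plain PyTorch (not the repo's actual code) of why `assign=True` helps: with the default `assign=False`, `load_state_dict` copies each checkpoint tensor into the module's already-allocated parameters, so two full copies of the weights are live at peak, while `assign=True` hands the checkpoint tensors to the module directly.

```python
import torch
import torch.nn as nn

# Minimal sketch of the mechanism in plain PyTorch (not TransformerLens code).
# With the default assign=False, load_state_dict copies each checkpoint tensor
# into the module's already-allocated parameters, so two full copies of the
# weights are live at peak. With assign=True (PyTorch >= 2.1), the checkpoint
# tensors are assigned to the module directly, avoiding the extra copy.

model = nn.Linear(4096, 4096)
checkpoint = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}

model.load_state_dict(checkpoint, assign=True)
```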
So, a simple change: expose an argument that lets users opt in to assign=True
(if the installed PyTorch version is recent enough).
(Enabling this behaviour by default broke tests, probably because the state dict gets deallocated, so I've exposed it as an opt-in argument instead.)
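Roughly, the opt-in could look like the sketch below; the argument name, the version check, and the plumbing are illustrative assumptions, not the final API in this PR.

```python
import torch
from packaging import version

# Hedged sketch of the opt-in, not the actual implementation in this PR.
# The argument name (use_assign) and the plumbing are assumptions for
# illustration only.
def load_and_process_state_dict_sketch(model, state_dict, use_assign=False):
    torch_supports_assign = version.parse(torch.__version__) >= version.parse("2.1")
    if use_assign and torch_supports_assign:
        # Reuse the state-dict tensors directly instead of copying them into
        # freshly allocated parameters, roughly halving peak memory.
        model.load_state_dict(state_dict, assign=True)
    else:
        # Fall back to the default copying behaviour on older PyTorch versions
        # or when the caller does not opt in.
        model.load_state_dict(state_dict)
```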
Fixes #480.
Type of change
Bug fix or new feature, depending on perspective. It should allow more models to be loaded without needing excessive memory.
Checklist:
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have not rewritten tests relating to key interfaces which would affect backward compatibility