Load state dict with assign to avoid OOMs

Open cyber-chris opened this issue 5 months ago • 0 comments

Description

I personally keep getting OOMs when trying to load llama3 8B. I narrowed down the offending code to load_and_process_state_dict, then realized Neel offered a suggestion in #480 that isn't implemented yet.

So, a simple change: expose an arg that lets users request assign=True when loading the state dict (if the installed PyTorch version is recent enough).

(It broke tests when I tried to enable this behaviour by default, probably because the state dict gets deallocated or something. So instead I've added it as an arg.)
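For reference, a minimal sketch of the idea, not the actual diff in this PR. The helper name `load_state_dict_maybe_assign` is hypothetical, and it assumes the only compatibility concern is whether the installed PyTorch's `load_state_dict` accepts an `assign` keyword (added in PyTorch 2.1, as far as I know). It checks the method signature instead of parsing version strings, and falls back to the default copying behaviour on older versions.

```python
import inspect


def load_state_dict_maybe_assign(model, state_dict, assign=False, strict=True):
    """Hypothetical helper: forward assign=True only if the model supports it.

    `model` is expected to be a torch.nn.Module, but anything exposing a
    compatible `load_state_dict` method works, which keeps this sketch
    runnable without PyTorch installed.
    """
    # Older PyTorch versions do not have the `assign` keyword at all,
    # so inspect the bound method rather than comparing version strings.
    params = inspect.signature(model.load_state_dict).parameters
    supports_assign = "assign" in params

    if assign and supports_assign:
        # With assign=True the module adopts the tensors from `state_dict`
        # instead of copying them into its already-allocated parameters.
        return model.load_state_dict(state_dict, strict=strict, assign=True)

    # Default path: copy values into the existing parameters.
    return model.load_state_dict(state_dict, strict=strict)
```

With a real model this would be called as `load_state_dict_maybe_assign(model, state_dict, assign=True)`. Keeping `assign=False` as the default matches the opt-in behaviour described above, since enabling it unconditionally broke tests.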

Fixes #480. (issue)

Type of change

Bug fix or new feature, depending on perspective. Should allow loading more models without needing a ton of memory.

Screenshots

Please attach before and after screenshots of the change if applicable.

Checklist:

  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes
  • [ ] I have not rewritten tests relating to key interfaces which would affect backward compatibility

cyber-chris, Sep 15 '24 11:09