litgpt
Memory usage improvements
See posted comments for in-depth explanations.
Memory usage was gathered with:

```python
import torch

with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p:
    # the training loop
    ...

from torch.cuda._memory_viz import profile_plot

with open("memory.html", "w") as f:
    f.write(profile_plot(p))
```
and setting:

```diff
- max_len = max(len(s) for s in input_ids) if fabric.device.type != "xla" else max_seq_length
+ max_len = max_seq_length
```
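The intent of that diff, for profiling purposes, can be illustrated with a small pure-Python sketch (the `pad` helper and the example batch below are hypothetical, for illustration only): padding every batch to the same fixed length makes tensor shapes, and therefore allocations, identical across batches, so peak-memory numbers are comparable between runs.

```python
# Hypothetical batch of token-id sequences with varying lengths.
input_ids = [[1, 2, 3], [4, 5], [6]]
max_seq_length = 8

# Before: pad only to the longest sequence in this batch (on non-XLA
# devices), so shapes vary from batch to batch.
dynamic_max_len = max(len(s) for s in input_ids)  # 3 for this batch

# After (while measuring memory): always pad to max_seq_length, so every
# batch allocates identically shaped tensors.
fixed_max_len = max_seq_length

def pad(batch, length, pad_id=0):
    # Hypothetical helper: right-pad each sequence to `length` with pad_id.
    return [s + [pad_id] * (length - len(s)) for s in batch]

print(pad(input_ids, fixed_max_len))
```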
With the changes in this PR, running `python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision 16-true` with `max_seq_length=1079` reduces maximum memory allocated from 23.10 GB to 22.82 GB.
Huge thanks to @robieta for helping debug the backward memory spike.