
Memory usage improvements

Open carmocca opened this issue 2 years ago • 0 comments

See posted comments for in-depth explanations.

Memory usage was gathered with

    import torch
    from torch.cuda._memory_viz import profile_plot

    # Profile the training loop, recording tensor shapes, memory events, and stack traces
    with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p:
        # the training loop
        ...

    # Render an interactive memory timeline from the collected profile
    with open("memory.html", "w") as f:
        f.write(profile_plot(p))

and setting the following, so that every batch is padded to the full sequence length and peak memory is comparable across steps:

-    max_len = max(len(s) for s in input_ids) if fabric.device.type != "xla" else max_seq_length
+    max_len = max_seq_length

With the changes in this PR, running `python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision 16-true` with `max_seq_length=1079` reduces the maximum memory allocated from 23.10 GB to 22.82 GB.
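The "maximum memory allocated" figures above can be read from PyTorch's CUDA memory statistics. A minimal sketch (the `bytes_to_gb` helper is ours, and the training loop itself is elided):

```python
import torch

def bytes_to_gb(n: int) -> float:
    # Convert a byte count to gibibytes for readable reporting
    return n / 1024**3

if torch.cuda.is_available():
    # Clear peak statistics left over from earlier allocations
    torch.cuda.reset_peak_memory_stats()
    # ... run the training loop here ...
    peak = torch.cuda.max_memory_allocated()
    print(f"{bytes_to_gb(peak):.2f} GB maximum memory allocated")
```

Resetting the peak stats before the loop ensures the reported number reflects only the training run, not model loading.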

Huge thanks to @robieta for helping debug the backward memory spike.

carmocca · Jun 21 '23 01:06