Switch to parallel FFD bin packing algorithm (closes #1492)
Description
Replace the existing sample packing algorithm with a parallel implementation of first-fit-decreasing.
Motivation and Context
I noticed recently that we could get denser sample packing with a different algorithm. Looking into it more, FFD packs just as densely and is much faster than the heuristic I had 😅.
We can run FFD in parallel without losing much packing efficiency by packing samples in groups rather than all at once. On an i9-14900K, it takes 2.2s to pack 1M samples at 99.7% efficiency (the current multipack.py reaches 91.7% in 0.32s).
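For reference, a minimal sketch of first-fit-decreasing over token lengths (illustrative only, not the PR's implementation; `ffd_pack` and its parameters are made-up names for the example):

```python
from typing import List

def ffd_pack(lengths: List[int], bin_capacity: int) -> List[List[int]]:
    """First-fit-decreasing: sort items longest-first, place each into
    the first bin with room, and open a new bin when none fits."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []   # item indices per bin
    remaining: List[int] = []    # free capacity per bin
    for i in order:
        for b, free in enumerate(remaining):
            if lengths[i] <= free:
                bins[b].append(i)
                remaining[b] -= lengths[i]
                break
        else:
            bins.append([i])
            remaining.append(bin_capacity - lengths[i])
    return bins

# e.g. ffd_pack([512, 300, 220, 1024, 80], bin_capacity=1024)
# -> [[3], [0, 1, 4], [2]]  (indices into `lengths`)
```

The parallel variant would split the dataset into fixed-size groups and run this independently per group (e.g. via multiprocessing), giving up a little density for a near-linear speedup.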
I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add that back in. Two new config options are added: sample_packing_group_size controls the number of samples packed by each process, and sample_packing_bin_size sets the maximum number of samples that can be placed in one pack (it may need to be increased for large context lengths).
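For reference, the new options would sit alongside the existing packing settings in the usual axolotl YAML config; the values below are illustrative placeholders, not recommended defaults:

```yaml
sample_packing: true
# number of samples handed to each packing process
sample_packing_group_size: 100000
# max samples allowed in a single pack; may need to be
# raised for large context lengths
sample_packing_bin_size: 200
```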
How has this been tested?
Tests have been updated to verify that packing is correct. Training appears to run the same, just with fewer steps.
It seems plausible that the sorting step in FFD could interfere with shuffling between epochs, but I haven't been able to find any evidence of that happening. Testing against a few similarity metrics shows that even when all the samples are packed at once in a single group, shuffling still produces a mostly new set of packs.
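One way to check this, as a sketch (the PR doesn't say which similarity metrics were used, so the Jaccard-style overlap below is an assumption):

```python
from typing import FrozenSet, List, Set

def pack_overlap(packs_a: List[List[int]], packs_b: List[List[int]]) -> float:
    """Jaccard similarity between two packings, treating each pack as an
    unordered set of sample indices. Values near 0 mean shuffling
    produced a mostly new set of packs."""
    a: Set[FrozenSet[int]] = {frozenset(p) for p in packs_a}
    b: Set[FrozenSet[int]] = {frozenset(p) for p in packs_b}
    return len(a & b) / len(a | b)
```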
Screenshots
Some performance checks below for 1M items.
> I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add that back in.
I need to do some checking, but the estimates exist because different processes get different splits of the data, so the actual count of packed samples can vary from process to process. When that happens, one process thinks it needs to run another step while another thinks it's done, and they fall out of sync. The estimate was the sanest way I could find to give each process a deterministic length. I'm open to other ideas for working around this.
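To illustrate the idea (a sketch only; the function name and the efficiency constant are assumptions, not the sampler's actual code), a deterministic length can be derived from global statistics so every rank computes the same number regardless of its local split:

```python
import math

def estimated_batches(total_tokens: int, seq_len: int, batch_size: int,
                      world_size: int, efficiency: float = 0.97) -> int:
    # Assume each pack holds roughly `efficiency * seq_len` tokens, so the
    # estimate depends only on global stats every rank agrees on, not on
    # the rank's local share of the data.
    est_packs = math.floor(total_tokens / (seq_len * efficiency))
    return est_packs // (batch_size * world_size)
```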
Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.
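A sketch of that splitting (a hypothetical helper, not the updated multipack.py itself): build all packs once, then truncate to a multiple of the world size so every rank receives the identical count and no rank runs an extra step.

```python
from typing import List

def split_packs_across_ranks(packs: List[List[int]], world_size: int,
                             rank: int) -> List[List[int]]:
    # Drop the remainder so len(result) is identical on every rank,
    # keeping all ranks in sync at the end of the epoch.
    per_rank = len(packs) // world_size
    start = rank * per_rank
    return packs[start:start + per_rank]
```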
> Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.

Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications, though.
Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?
> Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications, though.
Gotcha, for now I'll keep this PR simple by leaving the packing estimates in. Ready for another look.
> Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?
Yeah definitely, once the code is greenlit/finalized I'll rent an instance to test it in a distributed setup.
Hey @dsesclei, we cherry-picked and merged your fixes in #1619. Thanks! Would love to give you a shoutout if you're on Twitter or Discord and could share your handle.
Thanks for getting this in, Wing! No handles to give, but I appreciate it.
Thanks @dsesclei, I ended up having to revert the change because the loss was off by an order of magnitude. I'll need to dig into what the multipack sampler is outputting another time to see if there's something obvious it's doing differently.
Oh gotcha, I'll look into it