
Switch to parallel FFD bin packing algorithm (closes #1492)

dsesclei opened this issue 1 year ago • 5 comments

Description

Replace the existing sample packing algorithm with a parallel implementation of first-fit-decreasing.

Motivation and Context

I noticed recently that we could get denser sample packing with a different algorithm. Looking into it more, FFD performs just as well and is much faster than the heuristic I had 😅.

We can run FFD in parallel without losing much packing efficiency by packing samples in groups rather than all at once. On an i9-14900K, it takes 2.2s to pack 1M samples at 99.7% efficiency (the current multipack.py reaches 91.7% in 0.32s).
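For reference, here's a minimal sketch of what group-wise FFD looks like (illustrative names and a naive inner loop, not the actual multipack code — the real implementation can vectorize this; the point is the shape of the algorithm):

```python
from concurrent.futures import ProcessPoolExecutor

def ffd_pack(lengths, max_seq_len):
    """Pack sample indices into bins of capacity max_seq_len with first-fit-decreasing."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins, remaining = [], []  # remaining[b] = free space left in bins[b]
    for idx in order:
        for b, free in enumerate(remaining):
            if lengths[idx] <= free:
                bins[b].append(idx)
                remaining[b] -= lengths[idx]
                break
        else:  # no existing bin fits this sample, open a new one
            bins.append([idx])
            remaining.append(max_seq_len - lengths[idx])
    return bins

def pack_in_groups(lengths, max_seq_len, group_size):
    """Split samples into contiguous groups and pack each group independently (parallelizable)."""
    groups = [range(s, min(s + group_size, len(lengths)))
              for s in range(0, len(lengths), group_size)]
    with ProcessPoolExecutor() as pool:
        per_group = pool.map(ffd_pack,
                             [[lengths[i] for i in g] for g in groups],
                             [max_seq_len] * len(groups))
    packs = []
    for g, bins in zip(groups, per_group):
        packs.extend([[g.start + i for i in b] for b in bins])  # back to global indices
    return packs
```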

I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add them back in. Two new config options are added: sample_packing_group_size controls the number of samples packed by each process, and sample_packing_bin_size sets the number of samples that can be placed in one pack (this may need to be increased for large context lengths).
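Illustrative usage of the sketch above, showing how I read the two knobs (the mapping of sample_packing_bin_size to a per-pack item cap follows the description here, not a code review):

```python
lengths = [512, 1024, 256, 2048, 128]                 # per-sample token counts
packs = pack_in_groups(lengths, max_seq_len=4096, group_size=100_000)
# sample_packing_group_size -> group_size: samples handled by each worker process
# sample_packing_bin_size   -> an upper bound on how many samples one pack may hold;
#                              relevant when max_seq_len is large and many short
#                              samples would otherwise fit in a single pack
```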

How has this been tested?

Tests have been updated to verify that packing is correct. Training appears to run the same, just with fewer steps.
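The kind of invariants such a test checks (assumed, not the actual test code):

```python
def check_packing(packs, lengths, max_seq_len):
    flat = [i for pack in packs for i in pack]
    assert sorted(flat) == list(range(len(lengths))), "samples lost or duplicated"
    assert all(sum(lengths[i] for i in pack) <= max_seq_len for pack in packs), \
        "a pack exceeds max_seq_len"
```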

It seems plausible that the sorting step in FFD could interfere with shuffling between epochs, but I haven't found any evidence of that happening. Testing against a few similarity metrics shows that even when all packing is done in a single group, shuffling still generates a mostly new set of packs.
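One simple way to measure that kind of overlap (my own sketch, not necessarily the similarity metrics used above): compare the sets of packs produced under two different shuffles.

```python
import random

def pack_overlap(lengths, max_seq_len, seed_a, seed_b):
    """Jaccard similarity between the pack sets produced by two shuffles."""
    def packs_for(seed):
        idx = list(range(len(lengths)))
        random.Random(seed).shuffle(idx)
        bins = ffd_pack([lengths[i] for i in idx], max_seq_len)   # from the sketch above
        return {frozenset(idx[i] for i in b) for b in bins}
    a, b = packs_for(seed_a), packs_for(seed_b)
    return len(a & b) / len(a | b)   # 1.0 means both epochs produced identical packs
```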

Screenshots

Some performance checks below for 1M items.

[Plots: group size vs. excess, bin size vs. excess]

dsesclei avatar Apr 11 '24 02:04 dsesclei

I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add that back in.

I need to do some checking, but the estimates exist because different processes get different splits of the data, so the actual count of packed samples can vary from process to process. When this happens, one process thinks it needs to run another step while another thinks it's done, and they get out of sync. The estimate was the sanest way I could come up with to have each process arrive at a deterministic length. I'm open to other ideas for working around this.
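For context, the kind of deterministic estimate being described might look like this (a sketch of the idea, not the actual multipack code): every rank derives the same batch count from global token statistics instead of its own locally packed batches.

```python
import math

def estimated_batches(total_tokens, max_seq_len, batch_size, world_size,
                      packing_efficiency=0.97):
    # total_tokens is a global figure every rank can compute identically, so the
    # resulting length is the same everywhere even if local pack counts differ.
    estimated_packs = total_tokens / (max_seq_len * packing_efficiency)
    return math.floor(estimated_packs / (batch_size * world_size))
```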

winglian avatar Apr 11 '24 04:04 winglian

Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.
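A sketch of what that even split could look like (assumed, not the merged code): pack everything once, then hand each rank an equal-length, disjoint slice so step counts can't drift.

```python
def shard_packs(packs, world_size, rank):
    usable = (len(packs) // world_size) * world_size   # drop the remainder packs
    return packs[rank:usable:world_size]               # strided, equal-length shards
```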

dsesclei avatar Apr 11 '24 19:04 dsesclei

Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.

Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications, though.
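The rank-0-only idea could also be approximated by broadcasting the packs directly (a sketch of the idea, not a claim about how dispatch_batches itself works):

```python
import torch.distributed as dist

def packs_from_rank_zero(lengths, max_seq_len, group_size):
    obj = [None]
    if dist.get_rank() == 0:
        obj[0] = pack_in_groups(lengths, max_seq_len, group_size)  # from the sketch above
    dist.broadcast_object_list(obj, src=0)   # every rank receives rank 0's packs
    return obj[0]
```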

winglian avatar Apr 16 '24 23:04 winglian

Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?

NanoCode012 avatar Apr 18 '24 14:04 NanoCode012

Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications though

Gotcha, for now I'll keep this PR simple by leaving the packing estimates in. Ready for another look.

Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?

Yeah definitely, once the code is greenlit/finalized I'll rent an instance to test it in a distributed setup.

dsesclei avatar Apr 20 '24 00:04 dsesclei

Hey @dsesclei we cherry picked and merged your fixes in #1619. Thanks! Would love to give you a shoutout if you're on twitter or discord and could share your handle. thanks!

winglian avatar May 23 '24 21:05 winglian

Thanks for getting this in Wing! No handles to give, but I appreciate it

dsesclei avatar May 29 '24 22:05 dsesclei

Thanks @dsesclei, I ended up having to revert the change b/c the loss was off by an order of magnitude. I need to dig into what the multipack sampler is outputting at some point to see if there is something obvious that it is doing differently

winglian avatar May 29 '24 22:05 winglian

Oh gotcha, I'll look into it

dsesclei avatar May 29 '24 23:05 dsesclei