
Fast dataset resume

Open mariosasko opened this issue 8 months ago • 5 comments

This PR makes resuming dataset iteration from a checkpoint fast again.

This performance regression was introduced in https://github.com/pytorch/torchtitan/pull/838. That PR removed `.skip` for both map-style and iterable-style datasets for correctness reasons. However, `.skip` works as expected for map-style datasets, so the change can be reverted for that case. For iterable-style datasets, on the other hand, calling `.skip` after `split_dataset_by_node` splits the number of elements to skip across the ranks (e.g. calling `.skip(10)` after `split_dataset_by_node(<rank>, 2)` effectively skips 5 (10 // 2 = 5) elements on each rank), which isn't what we want/expect, so removing `.skip` was justified there. Still, we can make resuming much faster for iterable-style datasets by using the `state_dict` API, which avoids re-iterating past shards/files when resuming.
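The idea behind the `state_dict` approach can be sketched with a toy, pure-Python resumable iterable (this is not the actual `datasets` implementation; the class and field names are illustrative): track which shard you are in and the offset inside it, so resuming jumps straight to the saved position instead of re-consuming every earlier example.

```python
class ResumableShardedIterable:
    """Toy iterable over a list of shards that can checkpoint its position.

    Resuming seeks directly to the saved shard and in-shard offset instead
    of re-iterating all earlier examples (the slow path this PR removes).
    """

    def __init__(self, shards):
        self.shards = shards   # list of lists of examples
        self.shard_idx = 0     # which shard we are currently in
        self.offset = 0        # position inside the current shard

    def __iter__(self):
        while self.shard_idx < len(self.shards):
            shard = self.shards[self.shard_idx]
            while self.offset < len(shard):
                item = shard[self.offset]
                self.offset += 1
                yield item
            self.shard_idx += 1
            self.offset = 0

    def state_dict(self):
        return {"shard_idx": self.shard_idx, "offset": self.offset}

    def load_state_dict(self, state):
        self.shard_idx = state["shard_idx"]
        self.offset = state["offset"]


shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
ds = ResumableShardedIterable(shards)
it = iter(ds)
consumed = [next(it) for _ in range(4)]   # read 4 examples: [0, 1, 2, 3]
state = ds.state_dict()                   # {'shard_idx': 1, 'offset': 1}

ds2 = ResumableShardedIterable(shards)
ds2.load_state_dict(state)                # jump straight to shard 1, offset 1
resumed = list(ds2)                       # [4, 5, 6, 7, 8]
```

With real data files, the saved state points past whole shards/files, so resuming never has to open and re-read them.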

mariosasko avatar Apr 09 '25 19:04 mariosasko

Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.

fegin avatar Apr 14 '25 17:04 fegin

Sorry for the delay, I addressed the comments and made the test much more robust.

> I'm a bit confused by your comment. It sounds like the behavior of `skip` for `IterableDataset` is oblivious to whether it has gone through `split_dataset_by_node` or not, which is not intuitive?

It's hard to explain because the `datasets` logic seems buggy (e.g. this test in `datasets` should fail with correct parentheses). It should be easier to understand with a toy example. What surprised me is that changing the number of data shards results in a completely different behaviour.
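A toy sketch of why the shard count matters (illustrative only, mimicking the documented `split_dataset_by_node` semantics: each rank gets whole shards when the shard count divides evenly by the world size, and an example-level round-robin otherwise):

```python
def split_by_node(shards, rank, world_size):
    """Toy version of datasets.distributed.split_dataset_by_node semantics:
    whole-shard assignment when the shard count divides evenly, otherwise
    round-robin over the flattened examples."""
    if len(shards) % world_size == 0:
        own = shards[rank::world_size]            # keep every world_size-th shard
        return [x for shard in own for x in shard]
    flat = [x for shard in shards for x in shard]
    return flat[rank::world_size]                 # keep every world_size-th example


def skip(stream, n):
    return stream[n:]


# 4 shards, 2 ranks: shard-level split -> rank 0 sees [0, 1, 4, 5]
even = split_by_node([[0, 1], [2, 3], [4, 5], [6, 7]], rank=0, world_size=2)

# 3 shards, 2 ranks: example-level round-robin -> rank 0 sees [0, 2, 4]
odd = split_by_node([[0, 1], [2, 3], [4, 5]], rank=0, world_size=2)

# The same skip(2) now drops different examples depending only on shard count.
even_after_skip = skip(even, 2)   # [4, 5]
odd_after_skip = skip(odd, 2)     # [4]
```

So the elements that `skip` drops on a given rank depend on how the shard count interacts with the world size, not just on the skip count itself.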

> Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.

I don't have access to A100 / H100 GPUs right now, so it would be great if someone else could do the run. That said, I improved the test significantly (e.g. it now re-loops the test datasets), so I'm not sure the run is strictly needed.

mariosasko avatar Apr 22 '25 00:04 mariosasko

Can you also fix the linter error and the integration test error? I will see if I can verify with llama3.

fegin avatar Apr 24 '25 06:04 fegin

Unfortunately, one more bug needs to be fixed, this time directly in datasets ...

EDIT: Reported in datasets: https://github.com/huggingface/datasets/issues/7538

mariosasko avatar Apr 27 '25 18:04 mariosasko

Okay, I think this is ready for the verification test.

mariosasko avatar May 08 '25 09:05 mariosasko