Fast dataset resume
This PR makes resuming dataset iteration from a checkpoint fast again.
This performance regression was introduced in https://github.com/pytorch/torchtitan/pull/838. In that PR, .skip was removed for both map-style and iterable-style datasets for correctness reasons. However, .skip works as expected for map-style datasets, so the change can be reverted for that case. For iterable-style datasets, on the other hand, calling .skip after split_dataset_by_node splits the number of elements to skip across the ranks (e.g. calling .skip(10) after split_dataset_by_node(<rank>, 2) effectively skips 5 (10 // 2 = 5) elements on each rank), which isn't what we want/expect, so removing .skip was justified there. Still, we can make resuming much faster for iterable-style datasets by using the state_dict API, which avoids re-iterating past shards/files when resuming.
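To make the state_dict approach concrete, here is a minimal pure-Python sketch of the checkpoint/resume pattern. The class, its fields, and the `{"pos": ...}` payload are illustrative stand-ins; the real `datasets.IterableDataset` state dict tracks shard/file offsets internally, which is exactly what lets it resume without re-iterating:

```python
# Toy stateful iterable mimicking the state_dict() / load_state_dict()
# pattern exposed by Hugging Face datasets' IterableDataset.
# Illustration only: the real implementation records shard/file offsets,
# not a single integer position.

class ResumableRange:
    """Iterates 0..n-1 and can checkpoint/restore its position."""

    def __init__(self, n):
        self.n = n
        self._pos = 0

    def __iter__(self):
        while self._pos < self.n:
            value = self._pos
            self._pos += 1  # advance before yielding so checkpoints are exact
            yield value

    def state_dict(self):
        # Record the position instead of replaying the stream on resume.
        return {"pos": self._pos}

    def load_state_dict(self, state):
        self._pos = state["pos"]


ds = ResumableRange(10)
it = iter(ds)
first = [next(it) for _ in range(4)]  # consume 4 elements
ckpt = ds.state_dict()                # checkpoint: {"pos": 4}

resumed = ResumableRange(10)
resumed.load_state_dict(ckpt)         # resume without re-iterating 0..3
rest = list(resumed)

print(first, rest)  # [0, 1, 2, 3] [4, 5, 6, 7, 8, 9]
```

Because restoring is a constant-time seek rather than a replay of everything before the checkpoint, resume cost no longer grows with how far into the dataset training had progressed.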
Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.
Sorry for the delay, I addressed the comments and made the test much more robust.
I'm a bit confused by your comment. It sounds like the behavior of skip for IterableDataset is oblivious to whether it has gone through split_dataset_by_node or not, which is not intuitive?
It's hard to explain because the datasets logic seems buggy (e.g. this test in datasets should fail with correct parentheses). It should be easier to understand with a toy example. What surprised me is that changing the number of data shards results in completely different behaviour.
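Here is one plain-Python toy model that reproduces the numbers from the PR description (.skip(10) after a 2-way split only skipping 5 rank-local elements). The stride-based splitting and the skip being pushed down onto the shared stream are assumptions chosen to match the described behaviour, not a claim about the actual datasets implementation:

```python
# Hypothetical mechanism, for illustration only: assume that, without enough
# data shards, each rank reads the shared stream with a stride of world_size,
# and that a .skip(n) placed after the split is effectively applied to the
# shared underlying stream rather than to the per-rank stream.

def skip_after_split(data, rank, world_size, n_skip):
    # The skip is applied to the shared underlying stream ...
    remaining = data[n_skip:]
    # ... and the rank then takes its stride of whatever is left, so each
    # rank only skips n_skip // world_size of its own elements.
    return remaining[rank::world_size]

def skip_per_rank(data, rank, world_size, n_skip):
    # What one would intuitively expect: skip n_skip elements of the
    # rank-local stream.
    shard = data[rank::world_size]
    return shard[n_skip:]

data = list(range(40))
print(skip_after_split(data, 0, 2, 10))  # resumes at 10: only 5 rank-local elements skipped
print(skip_per_rank(data, 0, 2, 10))     # resumes at 20: 10 rank-local elements skipped
```

Under this model, the effective skip per rank also changes whenever the split strategy changes (e.g. shard-level splitting vs. a stride over examples), which would explain why varying the number of data shards produces completely different behaviour.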
Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.
I don't have access to A100 / H100 GPUs right now, so it would be great if someone else could do the run. I improved the test significantly (e.g. it now re-loops the test datasets), so I'm not sure this is still needed, though.
Can you also fix the linter error and the integration test error? I will try to verify with llama3.
Unfortunately, one more bug needs to be fixed, this time directly in datasets ...
EDIT:
Reported in datasets: https://github.com/huggingface/datasets/issues/7538
Okay, I think this is ready for the verification test.