
Use stateful dataloader to checkpoint data iteration order and token buffer

Open · gokulavasan opened this issue 10 months ago · 7 comments

Summary:

Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426.

Also make sure the dataloader state has a different key per rank.
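For context, here is a minimal sketch (not the actual torchtitan code; the helper names are made up) of how the StatefulDataLoader state could be keyed per rank when assembling the checkpoint state:

```python
# Sketch only: per-rank keying of StatefulDataLoader state for checkpointing.
import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader


def dataloader_state_for_checkpoint(dataloader: StatefulDataLoader) -> dict:
    # Key the state by rank so each data-parallel worker restores its own
    # iteration position and token buffer.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return {str(rank): dataloader.state_dict()}


def restore_dataloader_state(dataloader: StatefulDataLoader, state: dict) -> None:
    rank = dist.get_rank() if dist.is_initialized() else 0
    if str(rank) in state:
        dataloader.load_state_dict(state[str(rank)])
```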

Test Plan:

Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then I deleted the last 3 checkpoints, re-ran training, and verified that the loss values from steps 16-30 match those from the first run. Note that this requires changes in train.py to enable a deterministic run.
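For reference, a hedged sketch of the kind of determinism settings such a test typically needs in train.py; the exact changes in this PR may differ:

```python
# Sketch: settings commonly needed to make two training runs bitwise comparable.
import torch

torch.manual_seed(0)                        # fixed seed for model init and RNG
torch.use_deterministic_algorithms(True)    # fail loudly on nondeterministic ops
torch.backends.cudnn.benchmark = False      # avoid autotuned (run-dependent) kernels
```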

Reviewers: @tianyu-l

Subscribers: @andrewkho

Tasks:

Tags:

gokulavasan avatar Apr 26 '24 17:04 gokulavasan

Looks like adding the index URL (for torchdata) causes other dependencies to not get installed. Will figure out how to fix this.

gokulavasan avatar Apr 26 '24 18:04 gokulavasan

@gokulavasan, @tianyu-l https://github.com/pytorch/pytorch/pull/125335 and https://github.com/pytorch/pytorch/pull/125334 should unblock this PR.

fegin avatar May 01 '24 21:05 fegin

I've been testing this out, and ran into an issue with resuming from a checkpoint. I suspect it's because of how StatefulDataLoader handles the state dict: https://github.com/pytorch/data/blob/11e16da61d7f5f587627c75e99ea664efef3e0f8/torchdata/stateful_dataloader/stateful_dataloader.py#L249

That is, a freshly initialized StatefulDataLoader does not have a state dict to load into? I'm not very familiar with how DCP works, so please correct me if I'm wrong.

Edit: investigated a bit further, and indeed the state_dict for the data loader in DCP.load() is, for example, {'0': {}}, which causes it to be discarded by DefaultLoadPlanner.set_up_planner.
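A small sketch reproducing the symptom described here, under the assumption that state_dict() is empty before the iterator exists:

```python
# Sketch: a fresh StatefulDataLoader may report empty state, so the entry
# handed to DCP.load() looks like {'0': {}} and gets dropped by the planner.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(range(100), batch_size=4)
print(loader.state_dict())   # may be {} because no iterator exists yet

it = iter(loader)            # creating the iterator materializes the state
print(loader.state_dict())   # now contains iteration/worker state
```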

rlrs avatar May 10 '24 09:05 rlrs

@rlrs Would it be possible to test it after my latest commit (b9b045d)? I missed adding that part.

gokulavasan avatar May 10 '24 14:05 gokulavasan

> @rlrs Would it be possible to test it after my latest commit (b9b045d)? I missed adding that part.

I had already added that in my version. I can't get it to load the state_dict, unless I first call iter(dataloader) so that self._iterator is not None.

If I call iter before DCP.load, and then set self._first_iter = True in HuggingFaceDataset.load_state_dict, everything seems to work!
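A hedged sketch of that workaround (variable names are placeholders; the DCP load API may differ slightly across versions):

```python
# Sketch: create the iterator first so the loader has state to load into,
# then restore via DCP. `dataloader`, `state_dict`, and `checkpoint_path`
# are placeholder names.
import torch.distributed.checkpoint as dcp

iter(dataloader)            # ensures self._iterator is not None
dcp.load(state_dict, checkpoint_id=checkpoint_path)
# ...and inside HuggingFaceDataset.load_state_dict, reset the flag:
#     self._first_iter = True
```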

rlrs avatar May 10 '24 15:05 rlrs

@tianyu-l Addressed PR comments (thank you!), added a unit test, and made changes to the GitHub workflows to allow running those unit tests. Let me know if the changes look okay. Regarding the move to DTensor, I think this requires analysis of what the benefits are (especially for storing the unstructured state dict of the dataloader).

If it is purely to reduce the replication of state across tensor/pipeline-parallel groups, I think we can store the dataloader state just for the DP worker ranks (keyed by dp_rank_id) and load it back, instead of storing it for all global ranks. For now, with just text tokens, this might not even be necessary as the state is not that big. Let me know how you would like to proceed.

gokulavasan avatar May 17 '24 02:05 gokulavasan

@rlrs Thank you for your great analysis here (https://github.com/pytorch/torchtitan/pull/279#issuecomment-2104797493).

It helped us narrow down the issue, which basically boiled down to DCP's in-place loading of checkpoints: StatefulDataLoader currently returns no state if the dataloader iterator has not been created, while DCP expects the module to tell it which keys the module is expecting.

To get around this, I serialize the state of the dataloader, so in this case there is only one key to load, "<rank_id>", which is communicated by the DataLoaderWrapper to DCP.
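A minimal sketch of what such a wrapper could look like (assumed names and pickle-based serialization, not the exact PR code):

```python
# Sketch: serialize the whole StatefulDataLoader state under a single
# per-rank key so DCP always sees one known key, even before iteration.
import pickle
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
from torchdata.stateful_dataloader import StatefulDataLoader


class DataLoaderWrapper(Stateful):
    def __init__(self, dataloader: StatefulDataLoader):
        self.dataloader = dataloader
        rank = dist.get_rank() if dist.is_initialized() else 0
        self.rank_id = str(rank)

    def state_dict(self) -> dict:
        # Serialize to bytes so the value is opaque to DCP planning.
        return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}

    def load_state_dict(self, state_dict: dict) -> None:
        if self.rank_id in state_dict:
            self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id]))
```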

gokulavasan avatar May 17 '24 02:05 gokulavasan

Hi! I'm Quentin from HF :) FYI we just added state_dict() and load_state_dict() to datasets.IterableDataset, which can resume iteration faster than just skipping samples!
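For illustration, a short sketch of that API as documented for recent datasets releases:

```python
# Sketch: checkpoint and resume a datasets.IterableDataset mid-iteration.
from datasets import Dataset

ds = Dataset.from_dict({"tokens": list(range(32))}).to_iterable_dataset(num_shards=4)

state = None
for idx, example in enumerate(ds):
    if idx == 9:
        state = ds.state_dict()   # capture the position mid-epoch
        break

# Later (e.g. after a restart), resume from that position without
# re-reading and skipping the first samples.
ds.load_state_dict(state)
for example in ds:
    ...  # continues after the checkpointed example
```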

lhoestq avatar Jul 13 '24 10:07 lhoestq