torchtitan
Use stateful dataloader to checkpoint data iteration order and token buffer
Summary:
Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) to store the token buffer and data iteration order. This adds a dependency on the torchdata nightly build >= 20240426.
Also make sure the dataloader state has a different key per rank.
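A minimal sketch of what this looks like in practice (the helper names and rank wiring here are illustrative assumptions, not the actual torchtitan code):

```python
import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader  # nightly torchdata >= 20240426


def build_data_loader(dataset, batch_size):
    # Drop-in replacement for torch.utils.data.DataLoader that adds
    # state_dict()/load_state_dict(), capturing the iteration position and
    # (via the dataset's own state) any buffered-but-unconsumed tokens.
    return StatefulDataLoader(dataset, batch_size=batch_size, num_workers=2)


def data_loader_checkpoint_state(dl: StatefulDataLoader):
    # Use a different key per rank so each rank restores its own position.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return {str(rank): dl.state_dict()}
```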
Test Plan:
Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values, then deleting the last 3 checkpoints and re-running the training: the loss values from steps 16-30 match those from the first run. Note that this requires changes in train.py to enable a deterministic run.
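The "deterministic run" changes are not shown here; a rough sketch of what they might involve in train.py (the seeding choices below are assumptions, not the actual diff):

```python
import torch


def enable_deterministic_run(seed: int = 0):
    # Fix RNG seeds and disable non-deterministic kernels so a fresh run and a
    # run resumed from a checkpoint produce identical loss curves.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
```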
Reviewers: @tianyu-l
Subscribers: @andrewkho
Tasks:
Tags:
Looks like adding the index URL (for torchdata) is causing other dependencies to not get installed. Will figure out how to fix this.
@gokulavasan, @tianyu-l https://github.com/pytorch/pytorch/pull/125335 and https://github.com/pytorch/pytorch/pull/125334 should unblock this PR.
I've been testing this out, and ran into an issue with resuming from a checkpoint. I suspect it's because of how StatefulDataLoader handles the state dict: https://github.com/pytorch/data/blob/11e16da61d7f5f587627c75e99ea664efef3e0f8/torchdata/stateful_dataloader/stateful_dataloader.py#L249
That is, a freshly initialized StatefulDataLoader does not have a state dict to load into? I'm not very familiar with how DCP works, so please correct me if it's wrong.
Edit: investigated a bit further, and indeed the state_dict I get for the data loader in DCP.load() is, for example, '0': {}, which causes it to be discarded by DefaultLoadPlanner.set_up_planner.
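A small repro sketch of the behavior described above (the observed empty state comes from this report; exact behavior may differ across torchdata versions):

```python
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(100)), batch_size=4)

# Before any iteration there is no iterator state to report; in this report the
# resulting entry in the checkpoint state dict was just '0': {}.
print(loader.state_dict())

# An empty leaf like that is dropped by DefaultLoadPlanner.set_up_planner, so
# DCP.load() never hands the dataloader anything to load into.
```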
@rlrs Would it be possible to test it after my latest commit (b9b045d)? I missed adding that part.
I had already added that in my version. I can't get it to load the state_dict unless I first call iter(dataloader) so that self._iterator is not None.
If I call iter before DCP.load, and then set self._first_iter = True in HuggingFaceDataset.load_state_dict, everything seems to work!
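A hedged sketch of that workaround (the dataset, checkpoint path, and wiring are placeholders, not the actual torchtitan code):

```python
import torch.distributed.checkpoint as dcp
from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(list(range(1000)), batch_size=8)

# 1) Create the iterator first, so self._iterator is not None and state_dict()
#    is no longer an empty dict that the load planner would discard.
iter(dataloader)

# 2) Run the DCP load against a (hypothetical) checkpoint, then restore the loader.
state = {"dataloader": dataloader.state_dict()}
dcp.load(state, checkpoint_id="outputs/checkpoint/step-15")
dataloader.load_state_dict(state["dataloader"])

# 3) Inside HuggingFaceDataset.load_state_dict, additionally set
#        self._first_iter = True
#    so the restored position takes effect on the next iter().
```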
@tianyu-l Addressed PR comments (thank you!), added a unit test, and made changes to the GitHub workflows to allow running those unit tests. Let me know if the changes look okay. Regarding the move to DTensor, I think this requires an analysis of what the benefits are (especially for storing the unstructured state dict of the dataloader).
If it is purely to reduce the replication of state across tensor/pipeline parallel groups, I think we can store the dataloader state only for the DP worker ranks (by using the dp_rank_id as the key) and load it back, instead of storing it for all global ranks. For now, with just the text tokens, this might not even be necessary as the state is not that big. Let me know how you would like to proceed.
@rlrs Thank you for your great analysis here (https://github.com/pytorch/torchtitan/pull/279#issuecomment-2104797493).
It helped us narrow down the issue, which basically boils down to DCP's in-place loading of checkpoints: StatefulDataLoader currently returns no state if the dataloader iterator has not been created, while DCP expects the module to let it know which keys the module is expecting.
To get around this, I serialize the state of the dataloader, so there is only one key to load, which the DataLoaderWrapper communicates to DCP: "<rank_id>".
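A rough sketch of that wrapper approach, assuming a DataLoaderWrapper like the one described (this is not the exact torchtitan implementation):

```python
import pickle

import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
from torchdata.stateful_dataloader import StatefulDataLoader


class DataLoaderWrapper(Stateful):
    def __init__(self, dataloader: StatefulDataLoader):
        self.dataloader = dataloader
        # One entry per rank; a dp_rank-based key would also work to avoid
        # replicating state across tensor/pipeline parallel groups.
        self.rank_id = str(dist.get_rank()) if dist.is_initialized() else "0"

    def state_dict(self):
        # Serialize the whole loader state into bytes under a single, known key,
        # so DCP always sees a non-empty entry with a predictable name.
        return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}

    def load_state_dict(self, state_dict):
        # Nothing to restore if this rank has no saved entry (e.g. a fresh run).
        if state_dict.get(self.rank_id):
            self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id]))
```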
Hi! I'm Quentin from HF :)
FYI we just added state_dict() and load_state_dict() in datasets.IterableDataset, which can resume iteration faster than just skipping samples!
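A hedged sketch of how that could be used (requires a datasets release that includes these methods; the dataset choice below is just an example):

```python
from datasets import load_dataset

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Iterate for a while, then snapshot where we are.
for step, example in enumerate(ds):
    if step == 99:
        state = ds.state_dict()  # records the shard/position reached so far
        break

# Later: rebuild the streaming dataset and resume near the saved position,
# instead of re-reading and skipping the first 100 samples.
resumed = load_dataset("allenai/c4", "en", split="train", streaming=True)
resumed.load_state_dict(state)
next(iter(resumed))
```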