Rescalability layer
Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance, and when saving/loading checkpoints with a different total number of workers, stateful guarantees are maintained: seen data is not revisited until the next epoch.
Based on the datasets in the corresponding IBM torchtitan PR, but with an adjusted rescaling and iteration mechanism to support greater flexibility and robustness (removes divisibility constraints from worker and shard counts, and guarantees only one open file per physical worker regardless of the number of logical shards). Uses StatefulDataLoader and DCP to manage checkpointing from the master process. An epoch-completion testing script is included for demo purposes. It is possible that the IBM datasets can be merged into the existing torchdata Nodes structure.
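To make the intended flow concrete, here is a minimal sketch of the demo workflow. The constructor arguments of `ScalableReader` and the exact signatures of the checkpointing helpers (`save_distributed_state_dict` / `load_distributed_state_dict`, named later in this thread) are assumptions for illustration, not the PR's actual API:

```python
# Illustrative sketch only; constructor arguments and helper signatures are assumed.
from torchdata.stateful_dataloader import StatefulDataLoader

# Run 1: e.g. world size 8. The number of logical data partitions is fixed up front.
loader = StatefulDataLoader(ScalableReader(...), batch_size=8, num_workers=2)
for step, batch in enumerate(loader):
    ...  # training step
    if step == 1000:
        # Assumed helper from this PR: gathers dataset/loader state and writes it
        # via DCP from the master process.
        save_distributed_state_dict(loader, "checkpoint/")
        break

# Run 2: restarted with e.g. world size 4. The saved logical-shard states are
# redistributed across the new workers, so data seen in run 1 is not revisited
# until the next epoch.
loader = StatefulDataLoader(ScalableReader(...), batch_size=8, num_workers=2)
load_distributed_state_dict(loader, "checkpoint/")
```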
Changes
- Add IBM rescalable datasets and checkpointing functions to `torchdata/stateful_dataloader/ibm_rescalable.py`
- Add demo script and correctness check to `examples/ibm_rescaling/rescaling_demo.py`
Thanks for the work, @daviswer! Some first-level comments:
- Let's create some unit tests based on the example demo. I think it makes sense for them to live in `test/stateful_dataloader`.
- Let's name the Python file after the main abstraction users will use from it, so `ibm_rescalable.py` should become `scalable_reader.py`.
- The name `_WrapperDataset` is really generic. From the code and your comments, I think a name closer to the capability it provides might be `_NestedStatefulDataset`.
- The class `_ShardFileHandler` should probably be an abstract base class. And since we anticipate that others may end up creating their own shard handlers for other formats, we should probably consider a public API, so we should drop the leading `_`. We might also want to break it out into its own file, such as `shard_handler.py`. Then all future shard handlers would go in there.
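For concreteness, a hypothetical sketch of what a public abstract base class in `shard_handler.py` could look like; the method names here are illustrative guesses, not the handler's actual interface:

```python
# Hypothetical sketch; method names are illustrative, not the existing interface.
from abc import ABC, abstractmethod
from typing import Any


class ShardFileHandler(ABC):
    """Format-specific access to a single shard file (e.g. arrow, parquet, jsonl)."""

    @abstractmethod
    def open(self, path: str) -> Any:
        """Open the shard file and return a handle for subsequent reads."""

    @abstractmethod
    def length(self, handle: Any) -> int:
        """Return the number of documents/rows in the opened shard."""

    @abstractmethod
    def get(self, handle: Any, index: int) -> Any:
        """Return the item at `index` from the opened shard."""
```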
Thanks @scotts, I made changes 2-4 and am working on unit tests now. I'll note that `_StatefulDataset` and `_NestedStatefulDataset` largely represent legacy code, gluing things together until we decide whether we want to merge this into Nodes or use these to represent stateful datasets (in which case we'll need to rework them anyway with contracts/APIs/etc. per #1456).
Sharing some points that we discussed over the call.

- The core work here is in the data access layer, not particularly in the data loader. I imagine we can figure out a way to give end users (think PyTorch users with an established Dataset class) a `RescalableDataset` wrapper which converts their existing Dataset into one that can be rescaled if one decides to start and restart a job with a different world size. The `ScalableReader` is effectively that, although we should consider whether we want to make the user give more inputs (like a custom file handler) or whether we can configure those inside the rescalable dataset wrapper. This can feed directly into a canonical `StatefulDataLoader`, with `{save, load}_distributed_state_dict` functionality incorporated into `StatefulDataLoader`'s `state_dict` / `load_state_dict` methods as special cases for `RescalableDataset`. At this point I don't know how feasible that is (@daviswer brought up a good point about whether we want to take a dependency on DCP, versus having a generic interface for any checkpointing API to work), but this seems like a simpler interface for users to onboard to. A rough sketch of this wrapper idea follows after this list.
- So far the implementation is solving for text-heavy AI workloads. We should also align on whether we want to extend the scope to include other styles, e.g. a data row being an arbitrary `Dict[str, Any]`, typical map-style datasets, typical HuggingFace vision datasets, etc.
- I need to look at some internal data-access layer APIs to ensure we don't diverge too much.
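A rough sketch of the wrapper idea from the first point above, assuming a hypothetical `RescalableDataset` that implements the `state_dict`/`load_state_dict` protocol `StatefulDataLoader` already supports for dataset state; the actual rescaling logic is elided:

```python
# Hypothetical sketch; class name, arguments, and state layout are assumptions.
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader


class RescalableDataset(IterableDataset):
    """Wraps a user's existing dataset, tracking progress per logical shard so the
    saved state can be re-partitioned across a different world size on reload."""

    def __init__(self, dataset, n_logical_shards, rank, world_size):
        self.dataset = dataset
        self.n_logical_shards = n_logical_shards
        self.rank, self.world_size = rank, world_size
        self.progress = {}  # logical shard id -> items consumed this epoch

    def __iter__(self):
        # A real implementation would interleave only the logical shards owned by
        # this rank/worker and update self.progress as items are yielded.
        yield from self.dataset

    def state_dict(self):
        return {"n_logical_shards": self.n_logical_shards, "progress": self.progress}

    def load_state_dict(self, sd):
        # On rescale, redistribute per-shard progress across the new
        # (rank, world_size) layout rather than restoring rank-local state verbatim.
        self.progress = sd["progress"]


# Users would keep their existing Dataset and opt in via the wrapper:
# loader = StatefulDataLoader(RescalableDataset(ds, 64, rank, world_size), batch_size=8)
```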
@scotts @daviswer
Some thoughts on point 2 @divyanshk : we could definitely separate this out into a file-interface system, plus a separate rescaling-enabled interface between nested iterable Datasets and generic indexable structures. However, we may lose some capabilities in the process. In particular, the current approach is set up to a) handle sharding of indexable structures where the total number of indices is not known in advance (i.e. many shard files containing many documents, with limited access bandwidth) and b) ensure that no more than one file per device is open/active at a time, regardless of number of files/devices. If we abstract away notions of files/items behind a generic indexable interface, it becomes harder to maintain these guarantees.
It may be possible to still make that work but I'd have to think through the approach some more.
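To illustrate the second guarantee, here is a toy sketch (not the PR's iteration logic) of how a worker can cover many logical shards while keeping at most one file open at a time; the `handler` is assumed to expose `open`/`length`/`get` as in the interface sketched earlier:

```python
# Toy illustration of the one-open-file-per-worker property; not the PR's code.
def iterate_logical_shards(shard_paths, handler, progress):
    """Yield items from this worker's shards, opening one file at a time.

    `progress` maps shard path -> number of items already yielded this epoch,
    so iteration can resume mid-shard after a checkpoint reload.
    """
    for path in shard_paths:
        handle = handler.open(path)   # the only open file at this moment
        n = handler.length(handle)    # index count discovered lazily, per file
        for i in range(progress.get(path, 0), n):
            yield handler.get(handle, i)
            progress[path] = i + 1
        handle.close()                # closed before the next shard is opened
```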
@scotts