
Rescalability layer

Open daviswer opened this issue 9 months ago • 4 comments

Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance, and when checkpoints are saved and loaded with a different total number of workers, stateful guarantees are maintained: seen data is not revisited until the next epoch.

Based on the datasets in the corresponding IBM torchtitan PR, but with an adjusted rescaling and iteration mechanism for greater flexibility and robustness: divisibility constraints on worker and shard counts are removed, and only one file per physical worker is open at a time, regardless of the number of logical shards. Uses StatefulDataLoader and DCP to manage checkpointing from the master process. An epoch-completion testing script is included for demo purposes. It is possible that the IBM datasets can be merged into the existing torchdata Nodes structure.
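
A rough sketch of the intended flow (simplified for illustration, not the PR's exact helper API; see the demo script below for actual usage — the checkpoint path and loader arguments here are placeholders):

```python
# Illustrative sketch: checkpoint a StatefulDataLoader with DCP, then resume
# at a possibly different world size. `dataset` is assumed to be a rescalable
# dataset (e.g. a ScalableReader) constructed elsewhere.
import torch.distributed.checkpoint as dcp
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(dataset, batch_size=8, num_workers=4)

for step, batch in enumerate(loader):
    if step == 100:
        # Each rank contributes its loader state; DCP writes one checkpoint.
        dcp.save({"loader": loader.state_dict()}, checkpoint_id="/tmp/ckpt")
        break

# On restart, possibly with a different number of total workers, load the
# state back in; the rescaling layer redistributes only the unseen data.
state = {"loader": loader.state_dict()}  # template for DCP's in-place load
dcp.load(state, checkpoint_id="/tmp/ckpt")
loader.load_state_dict(state["loader"])
```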

Changes

  • Add IBM rescalable datasets and checkpointing functions to torchdata/stateful_dataloader/ibm_rescalable.py
  • Add demo script and correctness check to examples/ibm_rescaling/rescaling_demo.py

daviswer avatar Feb 25 '25 22:02 daviswer

Thanks for the work, @daviswer! Some first-level comments:

  1. Let's create some unit tests based on the example demo. I think it makes sense for them to live in test/stateful_dataloader.
  2. Let's name the Python file after the main abstraction users will use from it, so ibm_rescalable.py should become scalable_reader.py.
  3. The name _WrapperDataset is really generic. From the code and your comments, I think a name closer to the capability it provides might be _NestedStatefulDataset.
  4. The class _ShardFileHandler should probably be an abstract base class. And since we anticipate that others may end up creating their own shard handlers for other formats, we should probably consider a public API, so we should drop the leading _. We might also want to break it out into its own file, such as shard_handler.py; all future shard handlers would then go in there (see the sketch below).
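
Something like this sketch could work for the public ABC (the method names are a guess at the shape, not a final contract):

```python
# Sketch of a public shard-handler ABC for shard_handler.py; method names
# and signatures are illustrative, not the PR's final API.
from abc import ABC, abstractmethod
from typing import Any

class ShardFileHandler(ABC):
    """Format-specific access to a single shard file."""

    @abstractmethod
    def open(self, path: str) -> Any:
        """Open the shard file and return a reader handle."""

    @abstractmethod
    def length(self, reader: Any) -> int:
        """Return the number of documents in the opened shard."""

    @abstractmethod
    def get(self, reader: Any, index: int) -> Any:
        """Return the document at the given index within the shard."""
```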

scotts avatar Feb 28 '25 20:02 scotts

Thanks @scotts , I made changes 2–4 and am working on unit tests now. I'll note that _StatefulDataset and _NestedStatefulDataset largely represent legacy code, gluing things together until we decide whether we want to merge this into Nodes or use these to represent stateful datasets (in which case we'll need to rework them anyway with contracts/APIs/etc. per #1456)

daviswer avatar Mar 05 '25 19:03 daviswer

Sharing some points that we discussed over the call.

  1. The core work here is in the data access layer, not particularly in the data loader. I imagine we can figure out a way to give end users (think PyTorch users with an established Dataset class) a RescalableDataset wrapper which converts their existing Dataset into one that can be rescaled if they decide to start and restart a job with a different world size. The ScalableReader is effectively that, although we should consider whether we want to make the user provide more inputs (like a custom file handler) or whether we can configure those inside the rescalable dataset wrapper (see the sketch after this list).

    This can feed directly into a canonical StatefulDataLoader, with the {save, load}_distributed_state_dict functionality incorporated into StatefulDataLoader's state_dict / load_state_dict methods as special cases for RescalableDataset. At this point I don't know how feasible that is (@daviswer brought up a good point about whether we want to take a dependency on DCP, vs. having a generic interface that any checkpointing API can work with), but this seems like a simpler interface for users to onboard to.

  2. So far the implementation is solving for text-heavy AI workloads. We should also align on whether we want to extend the scope to other styles: e.g., a data row being an arbitrary Dict[str, Any], typical map-style datasets, typical HuggingFace vision datasets, etc.

  3. I need to look at some internal data-access layer APIs to ensure we don't diverge too much.
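
To make point 1 concrete, here is a hypothetical sketch of that onboarding flow. RescalableDataset below is a stand-in stub (ScalableReader plays this role in the PR), and n_logical_shards is an illustrative knob:

```python
# Hypothetical sketch of the wrapper idea; the stub body just passes data
# through, where the real logic would partition `inner` into logical shards
# and track which shards/offsets this worker has already consumed.
from torch.utils.data import Dataset, IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

class RescalableDataset(IterableDataset):
    def __init__(self, inner: Dataset, n_logical_shards: int = 256):
        self.inner = inner
        self.n_logical_shards = n_logical_shards  # fixed across restarts

    def __iter__(self):
        # Placeholder: yield everything in order. The real wrapper would
        # yield only this worker's unseen items, resumable at any world size.
        yield from (self.inner[i] for i in range(len(self.inner)))

class ToyDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return {"id": i}

loader = StatefulDataLoader(RescalableDataset(ToyDataset()), batch_size=8)

# state_dict / load_state_dict would special-case RescalableDataset,
# delegating to the {save, load}_distributed_state_dict logic internally.
state = loader.state_dict()
loader.load_state_dict(state)
```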

@scotts @daviswer

divyanshk avatar Mar 14 '25 20:03 divyanshk

Some thoughts on point 2 @divyanshk : we could definitely separate this out into a file-interface system plus a rescaling-enabled interface between nested iterable Datasets and generic indexable structures. However, we may lose some capabilities in the process. In particular, the current approach is set up to (a) handle sharding of indexable structures where the total number of indices is not known in advance (i.e. many shard files containing many documents, with limited access bandwidth), and (b) ensure that no more than one file per device is open/active at a time, regardless of the number of files/devices. If we abstract away the notions of files and items behind a generic indexable interface, it becomes harder to maintain these guarantees.

It may be possible to still make that work, but I'd have to think through the approach some more.
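
As a toy illustration of guarantee (b), consider the loop below (my own simplification, not the PR's code; read_chunk is a hypothetical format-specific reader that returns up to chunk_size documents starting at the given offset):

```python
# Toy sketch: a worker owns several logical shards but drains them one chunk
# at a time, so at most one file is open at any moment regardless of how
# many logical shards the worker holds.
def worker_iter(shard_paths, read_chunk, chunk_size=64):
    offsets = {p: 0 for p in shard_paths}
    active = list(shard_paths)
    while active:
        for path in list(active):
            with open(path, "rb") as f:   # the only open file at this point
                docs = read_chunk(f, offsets[path], chunk_size)
            if not docs:                  # this logical shard is exhausted
                active.remove(path)
                continue
            offsets[path] += len(docs)
            yield from docs
```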

@scotts

daviswer avatar Mar 26 '25 21:03 daviswer