
Discussion: DCP APIs and broader contracts for rescalability

daviswer opened this issue on Feb 25, 2025 · 2 comments

After much discussion, it was decided that the best approach to implementing rescalability would be to implement rescaling in the base file reader, to maintain low overhead and avoid proliferation of logical shard objects (see #1372, #1455, torchtitan PR). However, this approach necessitates that all nodes above the base node become rescaling-aware: we must decide what behaviors to support and how to make specifying these behaviors friendly to the user.

I have identified four behaviors that I believe a fully capable rescalable pipeline should support, with some correspondence to the existing placement behaviors of DTensors:

  1. Drop on rescale. Certain values, such as scalars and RNG states, cannot be repartitioned and it makes no sense to try. These values should be dropped when rescaling but kept otherwise.
  2. Sharded save, sharded load. Large buffers (for example, a local shuffling buffer) can be pooled into a single DTensor, which is then resharded over a new number of workers when rescaling. DCP is largely built around supporting this particular behavior, but note that we must now handle cases where the number of workers may not divide the length of the buffer evenly, and we also may not know the length of the buffer in advance.
  3. Replicated values. This encompasses any expensive metadata objects that we may want to construct (slowly) once, but load from checkpoint afterwards. These values would ideally be saved from rank 0 only, but loaded back to all workers. DCP supports this behavior for non-DTensor objects.
  4. Sharded save, global load. Any state that cannot be resharded simply via (2), such as logical shard state, which must first be accumulated/divided into global pools of visited vs unvisited shards. Local values are saved from each rank, but accumulated globally on load. DCP supports this behavior for non-DTensor objects, by assigning a unique rank-based key to each such object and reassembling them manually on load.

Note that while the above 4 behaviors raise some questions on DCP support, the larger question revolves around how we want to expose these options to users and/or incorporate them into existing Datasets or Nodes.
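For concreteness, here is a minimal sketch (recent PyTorch, nothing torchdata-specific) of how behaviors (2) and (3) map onto today's DCP/DTensor primitives. The names `shuffle_buffer` and `dataset_metadata` are placeholders, not a proposed API, and the process group is assumed to already be initialized:

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# One 1-D mesh over all workers (process group assumed to be initialized already).
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# (2) Sharded save, sharded load: pool each rank's local buffer into one DTensor
# sharded on dim 0. Note that from_local assumes evenly sized local shards when
# inferring the global shape unless an explicit shape/stride is given -- exactly
# the uneven-divisibility caveat noted above.
local_buffer = torch.arange(8)
pooled = DTensor.from_local(local_buffer, mesh, [Shard(0)])

# (3) Replicated values: plain (non-DTensor) objects are deduplicated on save and
# made available to every rank on load.
state_dict = {
    "shuffle_buffer": pooled,
    "dataset_metadata": {"num_files": 1024, "version": 3},
}
dcp.save(state_dict, checkpoint_id="checkpoint/step_100")

# On a job with a different world size, a fresh DTensor built over the new mesh would be
# placed under "shuffle_buffer" before calling dcp.load, and DCP redistributes the saved
# shards into it; the metadata dict is simply loaded back on every rank.
dcp.load(state_dict, checkpoint_id="checkpoint/step_100")
```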

daviswer · Feb 25, 2025

I thought through this a bit more, and I think I have a framework for handling these behaviors with minimal demands for changes to DCP. It also supports checkpoint saving/loading with or without DCP, by shunting all rescaling logic into DCP-specific save/load functions.

All state values carry an internal tag (class, attribute, prepended key name, etc.) signifying their rescaling behavior. If not saving/loading with DCP, rescaling is disabled, the tags are ignored, and state_dict and load_state_dict emit and consume the variables as normal, requiring no adjustments from the user. Four tag values are possible: Drop, Reshard, Replicate, and Custom (a sketch of what this tagging might look like follows the list below).

  • Drop: default tag. Values are saved/loaded normally, unless rescaling, in which case they are skipped. This encompasses things like RNG state, which it makes no sense to try to rescale.
  • Reshard: Encompasses things like buffers of items, which should be redistributed over the new set of workers when rescaling. Values must be torch tensors, and represent a single DTensor with sharding on dim 0. When rescaling, DCP handles the resharding and every worker gets its new slice of the DTensor values.
  • Replicate: For items that are duplicated across ranks, but should ideally only be stored once (e.g., dataset-level metadata).
  • Custom: For when more complex custom rescaling behavior is needed. Custom values are loaded back from checkpoint as either a single item, when not rescaling, or as a list of all the items across all the prior workers, when rescaling. If a list of items is given, a user-specified function is called that must return the given rank's new shard, constructed from the provided global state list.
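
To make the tagging concrete, here is one hypothetical way to express it as a key-name convention. None of these names exist in torchdata today, and the class/attribute variants mentioned above would work just as well:

```python
from enum import Enum

class RescaleTag(Enum):
    DROP = "drop"            # default: saved/loaded normally, skipped when rescaling
    RESHARD = "reshard"      # torch.Tensor, treated as one dim-0-sharded DTensor
    REPLICATE = "replicate"  # saved once (rank 0), loaded back on every rank
    CUSTOM = "custom"        # user function rebuilds this rank's shard from the global list

def tag_key(tag: RescaleTag, name: str) -> str:
    """Prepend the rescaling tag to a state_dict key, so that plain state_dict /
    load_state_dict round-trip the value untouched when DCP is not in the picture."""
    return f"{tag.value}::{name}"

# What a node's state_dict might look like under this convention:
# {
#     "drop::rng_state": ...,
#     "reshard::shuffle_buffer": ...,
#     "replicate::file_index": ...,
#     "custom::logical_shard_state": ...,
# }
```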

This now allows us to shunt all rescaling logic into DCP-specific save_distributed/load_distributed functions (location TBD - inside stateful_dataloader, inside the training code, provided by the user, etc.). save_distributed handles the four desired behaviors as follows: Drop variables are compiled across all workers on the same device but otherwise saved as normal; Reshard variables are concatenated on dim 0 across all workers on the same device, then wrapped in a DTensor; Replicate variables are dropped entirely except on rank 0; and Custom variables have their worker's rank prepended to their state dict key.
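A hypothetical save_distributed following the key-prefix convention sketched above - only the DTensor and DCP calls are real APIs; the tag parsing and the simplification of one dataloader worker per rank are assumptions:

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.tensor import DTensor, Shard

def save_distributed(state_dict, mesh, checkpoint_id):
    rank = dist.get_rank()
    out = {}
    for key, value in state_dict.items():
        tag, _ = key.split("::", 1)
        if tag == "drop":
            out[key] = value                    # saved normally (per-rank key uniqueness glossed over)
        elif tag == "reshard":
            # Local tensors across ranks become one DTensor sharded on dim 0;
            # DCP then writes each rank's slice.
            out[key] = DTensor.from_local(value, mesh, [Shard(0)])
        elif tag == "replicate":
            if rank == 0:
                out[key] = value                # deduplicate: only rank 0 contributes it
        elif tag == "custom":
            out[f"rank{rank}.{key}"] = value    # unique per-rank key, gathered globally on load
    dcp.save(out, checkpoint_id=checkpoint_id)
```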

load_distributed first determines whether rescaling is occurring, based on the current world size and the number of checkpoint shard files. Drop variables are loaded only if these match. DCP handles the remaining behaviors automatically: it reshards all DTensors (Reshard variables), replicates non-DTensors (Replicate variables) across all ranks, and gathers the full set of Custom variables from across ranks onto each rank, which load_distributed then compiles into a list. If rescaling, that list is passed back into the state dictionary; otherwise only the corresponding shard for the given rank is. Then load_state_dict proceeds as normal - any custom rescaling of Custom variables can happen either in this function or inside the corresponding dataset layer.
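A matching hypothetical load_distributed. Here count_checkpoint_ranks is a placeholder for however "how many ranks wrote this checkpoint" gets determined (e.g., from the checkpoint metadata or the number of shard files), and the placeholder-entry mechanics for non-tensor loads are deliberately simplified:

```python
import copy

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.tensor import DTensor, Shard

def load_distributed(state_dict, mesh, checkpoint_id):
    rank, world = dist.get_rank(), dist.get_world_size()
    prev_world = count_checkpoint_ranks(checkpoint_id)   # placeholder helper
    rescaling = prev_world != world

    to_load, custom_keys = {}, []
    for key, value in state_dict.items():
        tag, _ = key.split("::", 1)
        if tag == "drop":
            if not rescaling:                             # skipped entirely when rescaling
                to_load[key] = value
        elif tag == "reshard":
            # DCP reshards the saved DTensor into this rank's (possibly new-sized) local view.
            to_load[key] = DTensor.from_local(value, mesh, [Shard(0)])
        elif tag == "replicate":
            to_load[key] = value                          # loaded back onto every rank
        elif tag == "custom":
            ranks = range(prev_world) if rescaling else [rank]
            for r in ranks:                               # one placeholder entry per prior rank
                to_load[f"rank{r}.{key}"] = copy.deepcopy(value)
            custom_keys.append(key)

    dcp.load(to_load, checkpoint_id=checkpoint_id)

    # Unpack back into the user-facing state_dict.
    for key, value in to_load.items():
        if key in state_dict:
            state_dict[key] = value.to_local() if isinstance(value, DTensor) else value
    for key in custom_keys:
        if rescaling:   # full list of prior shards; the user's custom fn rebuilds this rank's shard
            state_dict[key] = [to_load[f"rank{r}.{key}"] for r in range(prev_world)]
        else:
            state_dict[key] = to_load[f"rank{rank}.{key}"]
    return state_dict
```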

The upsides of this approach are that it 1) enables auto-handling for a wide range of possible behaviors simply by marking class variables with the desired tags, 2) imposes minimal type restrictions on state variables (Reshard must be a torch.Tensor, but all others can be any type), 3) becomes invisible to anyone who doesn't want or need rescaling (either skip save_distributed/load_distributed, or just don't tag any state variables), and 4) simplifies utilities like the proposed ScalableReader (#1455) by shunting the rescaling logic out of the layer and into the distributed save/load functions (and any custom rescaling functions for Custom variables).

The drawback is that if we want to scale between arbitrary numbers of workers, Reshard variables will require DCP to gracefully handle DTensors where the world size does not cleanly divide size(dim 0). Worse, if we produce a checkpoint with the resulting off-by-one local shards and then rescale again, we'll need DCP to handle the above case plus uneven local shard sizes within each file. I don't think this is currently supported?
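To make the concern concrete (simple arithmetic, no DCP involved):

```python
# 10 buffer items over 3 workers -> local shards of 4, 3, 3, so the shard files are already
# uneven; rescaling that checkpoint to 4 workers then asks DCP to go from uneven inputs to
# uneven outputs (3, 3, 2, 2).
length = 10
for world in (3, 4):
    shards = [length // world + (1 if r < length % world else 0) for r in range(world)]
    print(world, shards)   # 3 [4, 3, 3]   /   4 [3, 3, 2, 2]
```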

daviswer · Mar 27, 2025

Thanks for the thorough review of the topic @daviswer.

Here are my thoughts on the topic after reviewing your text.

i\ Supporting (1), (2), and (3) is possible within the DCP stack. That does not mean we support all of these capabilities in our APIs right now (e.g., a query API, or the ability to describe a distributed tensor load that is not restricted by DTensor's strict equal-shard constraints), but they seem like fair asks and we have seen similar use cases.

ii\ My concern is with the Custom re-sharding fields, and the fact that #1455 uses them in addition to (1), (2), and (3). If we just treat DCP as a storage component (a KV store or similar), then we can have a fully custom re-scaling/re-sharding solution (collapsing (1), (2), and (3) into (4)) that uses DCP only for storage (state saved as serialized BytesIO).
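A minimal sketch of what that "DCP as storage only" alternative might look like - `dataloader` is a stand-in for whatever stateful object owns the state, and only io, torch.save, and dcp.save are real APIs here:

```python
import io

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

rank = dist.get_rank()
buf = io.BytesIO()
torch.save(dataloader.state_dict(), buf)   # fully custom, opaque per-rank state
dcp.save({f"rank{rank}.dataloader_state": buf}, checkpoint_id="checkpoint/step_100")
# On load, each rank requests whichever rank-keyed entries it needs and rebuilds its own
# shard with torch.load -- i.e., all rescaling logic lives in user code, outside DCP.
```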

iii\ How big is the data in (2), the buffers of items? Pros of still using DCP's re-sharding layer include:

  • Efficient handling of large-object re-sharding (in this case, DTensor-encoded state).
  • Easy handling of replicated tensors (users don't have to write special code to dedup during save).

iv\ I like the idea of save_distributed/load_distributed util functions that encapsulate the re-scaling logic. For the torchtitan integration it makes sense to use DCP as the storage layer (it keeps checkpointing consistent). However, I don't have the same clarity when it comes to using DCP's re-scaling capabilities during that integration, since the current PR includes state-management re-scaling that is not supported by DCP's internal re-scaling logic (which is trainer-state oriented right now).

pradeepfn · Apr 21, 2025