
[CheckpointServer] use streaming transfers

Open d4l3k opened this issue 1 year ago • 5 comments

The CheckpointServer currently uses torch.save/torch.load, which requires allocating the entire checkpoint buffer in memory. We instead want to use streaming transfers so we minimize the amount of CPU memory required.

It would also be nice to add checksums to these transfers to avoid any data corruption from the network.
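For example, a running CRC32 can be maintained with zlib.crc32 while copying the stream chunk by chunk, so the checksum costs no extra memory beyond one chunk. This is only a sketch of the idea, not the proposed wire format; the function name and chunk size are illustrative:

```python
# Hypothetical sketch: compute a zlib.crc32 checksum incrementally while
# streaming a buffer in fixed-size chunks, so the full payload never needs
# to be resident in memory at once.
import io
import zlib

CHUNK_SIZE = 64 * 1024  # arbitrary streaming chunk size

def stream_with_checksum(src, dst):
    """Copy src to dst in chunks, returning the running CRC32."""
    crc = 0
    while True:
        chunk = src.read(CHUNK_SIZE)
        if not chunk:
            break
        crc = zlib.crc32(chunk, crc)  # fold this chunk into the running CRC
        dst.write(chunk)
    return crc

payload = b"example tensor bytes" * 1000
out = io.BytesIO()
crc = stream_with_checksum(io.BytesIO(payload), out)
assert out.getvalue() == payload
assert crc == zlib.crc32(payload)
```

The receiver would recompute the CRC the same way while reading and compare it against the value sent by the server.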

Relevant existing code: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing.py#L72

The algorithm is described at: https://gist.github.com/d4l3k/b68094d649a076384967788c9b0a5f08

Existing tests: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing_test.py#L15

Overview of work:

  1. copy the write_state_dict and read_state_dict implementations from the gist above into checkpointing.py
  2. replace the existing torch.save/torch.load calls with them
  3. add unit tests for write_state_dict/read_state_dict covering all of the different kinds of torch tensors (different dtypes, strided layouts, storage offsets, scalars, nested structures, etc.)
  4. optionally add a checksum to write_state_dict/read_state_dict using zlib.crc32
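To make the shape of steps 1–2 concrete, here is a minimal sketch of a streaming write_state_dict/read_state_dict pair. The wire format below (a pickled metadata header followed by raw tensor bytes) is an assumption for illustration only, not the format described in the linked gist, and it handles plain CPU tensors rather than arbitrary nested state dicts:

```python
# Hypothetical sketch: stream each tensor's bytes directly to/from a
# file-like object instead of buffering the whole checkpoint with
# torch.save/torch.load.
import io
import pickle
import struct

import torch

def write_state_dict(state_dict, f):
    # Header: dtype and shape per tensor, so the reader knows how much to read.
    meta = {name: (str(t.dtype), tuple(t.shape)) for name, t in state_dict.items()}
    header = pickle.dumps(meta)
    f.write(struct.pack("<Q", len(header)))
    f.write(header)
    for t in state_dict.values():
        # .numpy() gives a view of the contiguous CPU tensor's raw bytes
        f.write(t.contiguous().cpu().numpy().tobytes())

def read_state_dict(f):
    (header_len,) = struct.unpack("<Q", f.read(8))
    meta = pickle.loads(f.read(header_len))
    out = {}
    for name, (dtype_str, shape) in meta.items():
        dtype = getattr(torch, dtype_str.split(".")[-1])  # "torch.float32" -> torch.float32
        nbytes = torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()
        buf = bytearray(f.read(nbytes))  # bytearray: frombuffer wants a writable buffer
        out[name] = torch.frombuffer(buf, dtype=dtype).reshape(shape)
    return out

sd = {"w": torch.arange(6, dtype=torch.float32).reshape(2, 3)}
buf = io.BytesIO()
write_state_dict(sd, buf)
buf.seek(0)
restored = read_state_dict(buf)
assert torch.equal(restored["w"], sd["w"])
```

A real implementation would also need to handle non-tensor values, strided/offset storages, and the checksum from step 4.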

d4l3k avatar Dec 12 '24 17:12 d4l3k

@Krishn1412 would you be interested in working on this?

d4l3k avatar Dec 12 '24 17:12 d4l3k

Sure @d4l3k, I'll work on this

Krishn1412 avatar Dec 13 '24 13:12 Krishn1412

Hey @d4l3k, can you take a look at this? https://github.com/pytorch-labs/torchft/pull/54

Krishn1412 avatar Dec 20 '24 15:12 Krishn1412

@d4l3k It seems that write_state_dict and read_state_dict won't work with DTensor. Please correct me if I'm wrong.

fegin avatar Dec 20 '24 17:12 fegin

@fegin yeah, that's a good point -- we should be able to support DTensor without too much trouble though
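One way this could look (a hedged sketch, not a committed design; the import path and the to_local() unwrapping step are assumptions) is to unwrap each DTensor into its local shard before serialization, passing regular tensors through unchanged:

```python
# Hypothetical sketch: unwrap DTensor into its local shard before writing;
# plain tensors pass through unchanged.
import torch

try:
    # Import location as of recent torch releases; may differ by version.
    from torch.distributed.tensor import DTensor
except ImportError:
    DTensor = None

def to_serializable(t):
    if DTensor is not None and isinstance(t, DTensor):
        return t.to_local()  # serialize only this rank's local shard
    return t

assert torch.equal(to_serializable(torch.ones(2)), torch.ones(2))
```

Restoring would then need to re-wrap the local shard with the original device mesh and placements on load.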

d4l3k avatar Dec 20 '24 21:12 d4l3k