[CheckpointServer] use streaming transfers
The CheckpointServer currently uses torch.save/torch.load, which requires materializing the entire serialized buffer in memory at once. We instead want to use streaming transfers so we minimize the amount of CPU memory required.
It would also be nice to add checksums to these transfers to detect any data corruption from the network.
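For illustration, here's a minimal sketch of the chunked-transfer idea: copy a payload in fixed-size chunks (so peak CPU memory is bounded by the chunk size, not the checkpoint size) while maintaining a running `zlib.crc32` to verify on the receiving end. The function names and chunk size are made up for this sketch, not the actual torchft API:

```python
import io
import zlib

CHUNK_SIZE = 64 * 1024  # stream in fixed-size chunks to cap CPU memory use

def stream_write(dst, src, length: int) -> int:
    """Copy `length` bytes from `src` to `dst` in chunks.

    Returns the running CRC32 of everything written so the sender
    can transmit it alongside the data.
    """
    crc = 0
    remaining = length
    while remaining > 0:
        chunk = src.read(min(CHUNK_SIZE, remaining))
        if not chunk:
            raise IOError("unexpected EOF while streaming")
        dst.write(chunk)
        crc = zlib.crc32(chunk, crc)  # incrementally extend the checksum
        remaining -= len(chunk)
    return crc

def stream_read(src, length: int, expected_crc: int) -> bytes:
    """Read `length` bytes from `src` in chunks, verifying the CRC32.

    This sketch accumulates into a BytesIO; a real implementation would
    write each chunk directly into preallocated tensor storage instead.
    """
    out = io.BytesIO()
    crc = 0
    remaining = length
    while remaining > 0:
        chunk = src.read(min(CHUNK_SIZE, remaining))
        if not chunk:
            raise IOError("unexpected EOF while streaming")
        out.write(chunk)
        crc = zlib.crc32(chunk, crc)
        remaining -= len(chunk)
    if crc != expected_crc:
        raise IOError(f"checksum mismatch: {crc:#010x} != {expected_crc:#010x}")
    return out.getvalue()
```

Note that `zlib.crc32` takes an optional running value as its second argument, so the checksum can be computed chunk-by-chunk without ever holding the full buffer.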
Relevant existing code: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing.py#L72
The algorithm is described at: https://gist.github.com/d4l3k/b68094d649a076384967788c9b0a5f08
Existing tests: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing_test.py#L15
Overview of work:
- copy over the write_state_dict and read_state_dict implementations into checkpointing.py
- replace existing torch.save/torch.load with those
- add unit tests for write_state_dict/read_state_dict for all the different possible types of torch tensors (different data types, strided, offsets, scalars, nested structures, etc)
- optionally add in checksum to read/write_state_dict that uses zlib.crc32
@Krishn1412 would you be interested in working on this?
Sure @d4l3k, I'll work on this
Hey @d4l3k, can you take a look at this? https://github.com/pytorch-labs/torchft/pull/54
@d4l3k It seems that write_state_dict and read_state_dict won't work with DTensor. Please correct me if I'm wrong.
@fegin yeah, that's a good point -- we should be able to support DTensor without too much trouble though