torchsnapshot

Loading tensors in lists/dicts that have not yet been instantiated

vmoens opened this issue on Oct 16, 2022 · 1 comment

🚀 The feature

We'd like to be able to load tensors that are saved on disk but do not yet exist in the destination module.

Motivation, pitch

Say we have a module that stores a list of tensors. During training, we append to that list.

If I'm using regular torch.save(state_dict), we end up with a dictionary containing a list of tensors, and we can simply load it back where it belongs (since loading is not done in place).

With torchsnapshot, what I understand is that the snapshot will look at my current state_dict and repopulate it in place. Hence, if my list of tensors is empty (which I expect it to be when I load a checkpoint), all the tensors saved in the list will be discarded.

Example:


from torchsnapshot import StateDict, Snapshot
import torch
import os

def list_files(startpath):
    """Print the directory tree under startpath."""
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print('{}{}'.format(subindent, f))

# a minimal "stateful" object whose state is a growing list of tensors
class ClassWithSD:
    def __init__(self):
        self.obj = []
    def state_dict(self):
        return {"obj": self.obj}
    def load_state_dict(self, sd):
        self.obj = sd["obj"]


x = ClassWithSD()

# let's put 2 tensors in our list; we'd like to get them back when loading
x.obj.append(torch.tensor([1.0]))
x.obj.append(torch.tensor([2.0]))

app_state = {"x": x}
Snapshot.take(app_state=app_state, path="./")


snapshot = Snapshot(path="./")
y = ClassWithSD()
app_state = {"x": y}
snapshot.restore(app_state=app_state)

list_files("./0")  # list_files prints the tree itself (it returns None)
print("content before take:", x.obj)
print("content after restore:", y.obj)  # empty: the saved tensors are discarded

# for comparison: torch.save / torch.load round-trips the list

torch.save(x.state_dict(), "torch_saved.pt")
y = ClassWithSD()
y.load_state_dict(torch.load("torch_saved.pt"))
print("torch.save:", y.obj)

Alternatives

No response

Additional context

Looking at this: https://github.com/pytorch/torchsnapshot/blob/4596fc6baf0fc9662cbfbc8d363cf115dc46d517/torchsnapshot/snapshot.py#L681-L736

I guess what I would like is: if not all available_entries are consumed by the in-place restore, the remaining logical_paths should still be loaded into the state_dict that is passed to stateful.load_state_dict(...) at line 736.
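
To illustrate the requested semantics with a self-contained toy (plain dicts standing in for the snapshot's manifest; this is not torchsnapshot code, and saved_entries is a hypothetical stand-in for available_entries):

# Toy illustration of the requested semantics; saved_entries stands in
# for the snapshot's available_entries and is NOT torchsnapshot's API.
saved_entries = {"obj/0": torch.tensor([1.0]), "obj/1": torch.tensor([2.0])}

state_dict = {"obj": []}  # the destination list is empty at restore time

# In-place matching found no slots, so graft the leftover entries into
# the state_dict before handing it to load_state_dict:
for logical_path in sorted(saved_entries):
    container = logical_path.split("/")[0]
    state_dict[container].append(saved_entries[logical_path])

y = ClassWithSD()
y.load_state_dict(state_dict)
print(y.obj)  # [tensor([1.]), tensor([2.])]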

vmoens · Oct 16 '22

@yifuwang was this fixed by https://github.com/pytorch/torchsnapshot/pull/104?

ananthsub · Oct 21 '22