torchsnapshot
Loading tensors in lists/dict that have not yet been instantiated
🚀 The feature
We'd like to be able to load tensors that are saved on disk but are not yet present in the destination module.
Motivation, pitch
Say we have a module that stores a list of tensors. During training, we increment that list.
If I use regular `torch.save(state_dict)`, we end up with a dictionary containing a list of tensors, and we can simply load it back where it belongs (since loading is not done in place).
With torchsnapshot, my understanding is that the snapshot will look at my current `state_dict` and repopulate it in place. Hence, if my list of tensors is empty (which I expect it to be when I load a checkpoint), all the tensors in the list will be discarded.
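To make the difference concrete, here is a minimal plain-Python sketch (no torch or torchsnapshot) contrasting the two loading styles described above. The function names are illustrative and do not correspond to torchsnapshot's real internals:

```python
# In-place style (as described for torchsnapshot's restore): only entries
# already present in the destination are visited, so saved entries with no
# counterpart in the destination are silently dropped.
def restore_in_place(current_state, saved_state):
    for key, dst in current_state.items():
        if isinstance(dst, list):
            # Copy element-wise into the existing (possibly empty) list.
            for i in range(min(len(dst), len(saved_state[key]))):
                dst[i] = saved_state[key][i]
    return current_state

# torch.save/torch.load style: the saved structure replaces the destination
# wholesale, so nothing is dropped.
def restore_by_replacement(current_state, saved_state):
    return dict(saved_state)

saved = {"obj": [1.0, 2.0]}

print(restore_in_place({"obj": []}, saved))       # {'obj': []} -- tensors lost
print(restore_by_replacement({"obj": []}, saved)) # {'obj': [1.0, 2.0]}
```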
Example:
```python
from torchsnapshot import StateDict, Snapshot
import torch
import os

def list_files(startpath):
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * level
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print('{}{}'.format(subindent, f))

class ClassWithSD:
    def __init__(self):
        self.obj = []

    def state_dict(self):
        return {"obj": self.obj}

    def load_state_dict(self, sd):
        self.obj = sd["obj"]

x = ClassWithSD()
# let's put 2 tensors in our list; we'd like to get them back when loading
x.obj.append(torch.tensor([1.0]))
x.obj.append(torch.tensor([2.0]))

app_state = {"x": x}
Snapshot.take(app_state=app_state, path="./")

snapshot = Snapshot(path="./")
y = ClassWithSD()
app_state = {"x": y}
snapshot.restore(app_state=app_state)

list_files("./0")
print("content before take:", x.obj)
print("content after restore:", y.obj)

# with torch.save
torch.save(x.state_dict(), "torch_saved.pt")
y = ClassWithSD()
y.load_state_dict(torch.load("torch_saved.pt"))
print("torch.save:", y.obj)
```
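For completeness, one workaround that follows from the in-place behavior described above: if the restore logic only fills entries that already exist in the destination, pre-populating the destination list with placeholders of the right length should let everything be recovered. This is a hedged plain-Python sketch that assumes the element count is known before restoring; the helper below is illustrative, not torchsnapshot's API:

```python
# In-place restore as described in the issue: copies into slots that
# already exist in the destination and ignores the rest.
def restore_in_place(current_state, saved_state):
    for key, dst in current_state.items():
        if isinstance(dst, list):
            for i in range(min(len(dst), len(saved_state[key]))):
                dst[i] = saved_state[key][i]
    return current_state

saved = {"obj": [1.0, 2.0]}

# Pre-size the destination with placeholders before restoring,
# so the in-place copy has slots to fill.
dest = {"obj": [0.0] * len(saved["obj"])}
restore_in_place(dest, saved)
print(dest)  # {'obj': [1.0, 2.0]}
```

The obvious downside is that the checkpoint's structure (here, the list length) must be known up front, which is exactly what this feature request would remove the need for.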
Alternatives
No response
Additional context
Looking at this: https://github.com/pytorch/torchsnapshot/blob/4596fc6baf0fc9662cbfbc8d363cf115dc46d517/torchsnapshot/snapshot.py#L681-L736
I guess what I would like is: if not all `available_entries` are loaded, the remaining `logical_path`s are still loaded into the state_dict that is passed to `stateful.load_state_dict(...)` at line 736.
@yifuwang was this fixed by https://github.com/pytorch/torchsnapshot/pull/104 ?