[BUG] MemoryMappedTensor Loading

Open suessmann opened this issue 1 year ago • 0 comments

Describe the bug

I collected a memmap tensordict on the cluster in a Jupyter notebook, following the guide in [1]. When I load the same memmap on my local machine with TensorDict.load_memmap(path), I get the error RuntimeError: Could not find name <class '__main__.ImageNetData'>. I suspect the issue is in the meta.json file of the memmap, where the type is recorded as <class '__main__.ImageNetData'>: on my local machine the class is defined in a separate module, and load_memmap(path) is not called from __main__.
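To illustrate the suspected cause, here is a minimal sketch of how Python itself names a class depending on the module it is defined in. It does not use tensordict and makes no claim about how tensordict writes meta.json; the helper `class_repr` is hypothetical and only builds an empty stand-in class with an explicit `__module__`.

```python
def class_repr(module_name, class_name="ImageNetData"):
    """Return str() of an empty stand-in class defined in `module_name`."""
    # type() builds a class dynamically; __module__ is set explicitly so the
    # result does not depend on where this snippet is executed.
    cls = type(class_name, (), {"__module__": module_name})
    return str(cls)


# A class defined in a notebook cell or a script lives in __main__ ...
notebook_repr = class_repr("__main__")
# ... while the same class defined in data.py lives in the module "data".
module_repr = class_repr("data")

print(notebook_repr)
print(module_repr)
```

The two strings differ only in the module prefix, which matches the difference between the value found in meta.json and the value that works after the manual edit described under "Reason and Possible fixes".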

To Reproduce

Follow [1] to create the memmap and note the path it was saved to. Then create main.py:

from data import Dataset

def main(path):
    data = Dataset(path)

if __name__ == '__main__':
    main('path/to/memmap')

and data.py:

import torch
from tensordict import MemoryMappedTensor, tensorclass, TensorDict

@tensorclass
class ImageNetData:
    images: torch.Tensor
    targets: torch.Tensor

class Dataset:
    def __init__(self, path):
        # Load the memmap that was saved from the notebook
        self.data = TensorDict.load_memmap(path)

Running main.py then fails with:

RuntimeError: Could not find name <class '__main__.ImageNetData'>

Expected behavior

The memmap loads without error, regardless of the module from which load_memmap is called.

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.5.0 1.26.4 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] linux 2.4.1+cu121

Reason and Possible fixes

I manually changed meta.json to

{"_type":"<class 'data.ImageNetData'>"}

This works, but hand-editing the metadata is not a robust solution. Another option is to use snapshots, but from the example in [2] it looks like loading a snapshot requires initializing the memmap each time, which is very time-consuming in my case (my data is more than 500 GB).
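The manual edit can be scripted. Below is a minimal sketch, assuming the metadata file is named meta.json, sits in the directory passed in, and stores the class under the "_type" key exactly as quoted above. The helper `retarget_type` is hypothetical and not part of tensordict; it only rewrites the module prefix and leaves every other key untouched. Back up the file before trying it.

```python
import json
from pathlib import Path


def retarget_type(memmap_dir, old_module, new_module, filename="meta.json"):
    """Rewrite the module prefix of the "_type" entry in a metadata file.

    Returns True if the file was changed, False if "_type" is missing or
    does not start with the expected "<class 'old_module." prefix.
    """
    meta_path = Path(memmap_dir) / filename
    meta = json.loads(meta_path.read_text())

    old_type = meta.get("_type")
    prefix = f"<class '{old_module}."
    if not isinstance(old_type, str) or not old_type.startswith(prefix):
        return False

    # Keep everything after the module prefix (class name and closing "'>").
    meta["_type"] = f"<class '{new_module}." + old_type[len(prefix):]
    meta_path.write_text(json.dumps(meta))
    return True
```

For the setup in this report the call would be retarget_type("path/to/memmap", "__main__", "data"). Whether other metadata files in the memmap directory need the same change is not covered by this sketch.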

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

[1] https://pytorch.org/tensordict/main/tutorials/tensorclass_imagenet.html
[2] https://github.com/pytorch/tensordict/blob/16595186c5b8fa8c9a735bf3639a36977f8a63e1/benchmarks/distributed/dataloading.py#L139

suessmann avatar Oct 21 '24 12:10 suessmann