
What is the right way to serialize DataLoader2 so that a pipeline with shuffle can resume from the right place?

Status: Open · zhengwy888 opened this issue 1 year ago · 2 comments

🐛 Describe the bug

I tried all of the approaches below. The only one that worked was the last, but it is too hacky. Is there a better way?

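    # Imports assumed for torch 2.0 / torchdata 0.7 (not shown in the original report).
    from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
    from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
    from torch.utils.data.graph import traverse_dps
    from torch.utils.data.graph_settings import get_all_graph_pipes
    from torchdata.dataloader2 import DataLoader2, InProcessReadingService
    from torchdata.dataloader2.serialization import deserialize_datapipe, serialize_datapipe
    from torchdata.datapipes.iter import IterableWrapper

    dp = IterableWrapper(list(range(20)))
    dp = dp.shuffle()
    rs = InProcessReadingService()
    dl = DataLoader2(dp, reading_service=rs)
    iter1 = iter(dl)
    for _ in range(4):
        next(iter1)

    # 16 elements left in dl
    state = dl.state_dict()

    # Attempt 1: DataLoader2.from_state
    dl2 = DataLoader2.from_state(state, reading_service=rs)
    # assert len(list(dl2)) == 20 - 4  # got 20

    # Attempt 2: serialize/deserialize the datapipe
    dp2 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    # assert len(list(dp2)) == 20 - 4  # got 20

    # Attempt 3: deserialize, then fast-forward by the number of samples yielded
    dp3 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    _simple_graph_snapshot_restoration(dp3, dp3._number_of_samples_yielded)
    ret3 = list(dp3)
    assert len(ret3) == 20 - 4
    # but the content is not the same

    # Attempt 4: from_state, then fast-forward
    dl4 = DataLoader2.from_state(state, reading_service=rs)
    _simple_graph_snapshot_restoration(dl4.datapipe, dl.datapipe._number_of_samples_yielded)
    ret4 = list(dl4)
    assert len(ret4) == 20 - 4
    # but the content is not the same

    # Attempt 5: save the shuffler's buffer and RNG state, fast-forward, then put them back
    dp5 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    for pipe in get_all_graph_pipes(traverse_dps(dp5)):  # every DataPipe in dp5's graph
        if isinstance(pipe, ShufflerIterDataPipe):
            shuffler = pipe
            buffer_cache = pipe._buffer[:]
            assert len(buffer_cache) == 20 - 4
            rng_state = pipe._rng.getstate()
    _simple_graph_snapshot_restoration(dp5, dl.datapipe._number_of_samples_yielded)
    shuffler._buffer = buffer_cache[:]
    shuffler._rng.setstate(rng_state)
    it5 = iter(dp5)
    ret5 = list(it5)
    assert len(ret5) == 20 - 4

    expected = list(iter1)
    # ret5 is the only method that worked
    # assert ret3 == expected
    # assert ret4 == expected
    assert ret5 == expected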
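The underlying problem can be sketched in plain Python, independent of torchdata. `ResumableShuffler` below is a hypothetical toy, not torchdata's `ShufflerIterDataPipe` and not any DataLoader2 API, and it assumes a single epoch over an in-memory list. It contrasts two ways to resume a buffered shuffle: checkpoint the pending buffer together with the RNG state (the idea behind attempt 5), or replay from the start and skip the samples already yielded, which only reproduces the same order when the replay reuses the original seed. One plausible reading of attempts 3 and 4 (right length, wrong content) is that the replayed shuffler was not seeded like the original; that is an assumption, not something verified against torchdata internals.

```python
import random
from typing import Any, Dict, Iterator, List


class ResumableShuffler:
    """Hypothetical single-epoch buffer shuffler (NOT torchdata's ShufflerIterDataPipe).

    Its checkpoint holds the read position, the pending buffer and the RNG
    state, i.e. the same pieces attempt 5 copies out of the real shuffler.
    """

    def __init__(self, source: List[Any], buffer_size: int = 10000, seed: int = 0) -> None:
        self.source = list(source)
        self.buffer_size = buffer_size
        self._rng = random.Random(seed)
        self._buffer: List[Any] = []
        self._pos = 0  # index of the next source element to read

    def __iter__(self) -> Iterator[Any]:
        # Fill the buffer; once it is full, emit one random element per new input.
        while self._pos < len(self.source):
            self._buffer.append(self.source[self._pos])
            self._pos += 1
            if len(self._buffer) == self.buffer_size:
                yield self._buffer.pop(self._rng.randrange(len(self._buffer)))
        # Source exhausted: drain the remaining buffer in random order.
        while self._buffer:
            yield self._buffer.pop(self._rng.randrange(len(self._buffer)))

    def state_dict(self) -> Dict[str, Any]:
        # Copy the buffer; Random.getstate() already returns an immutable tuple.
        return {"pos": self._pos, "buffer": list(self._buffer), "rng": self._rng.getstate()}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._pos = state["pos"]
        self._buffer = list(state["buffer"])
        self._rng.setstate(state["rng"])


source = list(range(20))
original = ResumableShuffler(source, seed=42)
it = iter(original)
first = [next(it) for _ in range(4)]
state = original.state_dict()  # checkpoint after 4 samples
expected = list(it)  # what the original iterator yields next

# Option 1: restore buffer + RNG state; the seed passed here does not matter.
restored = ResumableShuffler(source, seed=123)
restored.load_state_dict(state)
resumed = list(restored)

# Option 2: replay from scratch and skip 4 samples. This only matches because
# the same seed (42) is reused; a different seed would still give 16 samples,
# but generally in a different order.
replay_it = iter(ResumableShuffler(source, seed=42))
for _ in range(4):
    next(replay_it)
replayed = list(replay_it)

print(resumed == expected, replayed == expected)  # prints: True True
```

Option 1 keeps the checkpoint independent of the seed but has to store the whole pending buffer; option 2 only needs a seed and a sample count, at the cost of re-reading and re-shuffling everything that was already consumed.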

Versions

    PyTorch version: 2.0.0a0+gite9ebda2
    Is debug build: False
    CUDA used to build PyTorch: 12.0
    ROCM used to build PyTorch: N/A

    OS: Ubuntu 20.04.3 LTS (x86_64)
    GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
    Clang version: 12.0.1 (https://github.com/conda-forge/clangdev-feedstock d44358f44aef33e9fa7c5f93e2481ee8f1a04ab6)
    CMake version: version 3.19.1
    Libc version: glibc-2.31

    Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)  [GCC 10.3.0] (64-bit runtime)
    Python platform: Linux-5.4.0-64-generic-x86_64-with-glibc2.10
    Is CUDA available: False
    CUDA runtime version: 12.0.140
    GPU models and configuration: Could not collect
    Nvidia driver version: Could not collect
    cuDNN version: Could not collect
    HIP runtime version: N/A
    MIOpen runtime version: N/A
    Is XNNPACK available: False

    Versions of relevant libraries:
    [pip3] mypy-extensions==1.0.0
    [pip3] mypy-protobuf==3.3.0
    [pip3] numpy==1.23.5
    [pip3] pytorch3d==0.6.2
    [pip3] torch==2.0.1+1684801906.cuda120.cudnn891.nccl218.ap
    [pip3] torch-mlir==1684442443
    [pip3] torch-scatter==2.1.0
    [pip3] torch-tb-profiler==0.4.1
    [pip3] torchdata==0.7.0.dev20230601
    [pip3] torchfile==0.1.0
    [pip3] torchvision==0.15.1a0+42759b1
    [conda] magma-cuda121             2.6.1                         1    pytorch
    [conda] mkl                       2020.4             h726a3e6_304    conda-forge
    [conda] mkl-include               2023.1.0         h84fe81f_48680    conda-forge
    [conda] numpy                     1.23.5           py38h7042d01_0    conda-forge
    [conda] pytorch3d                 0.6.2                    pypi_0    pypi
    [conda] torch                     2.0.1+1684801906.cuda120.cudnn891.nccl218.ap          pypi_0    pypi
    [conda] torch-mlir                1684442443               pypi_0    pypi
    [conda] torch-scatter             2.1.0                    pypi_0    pypi
    [conda] torch-tb-profiler         0.4.1                    pypi_0    pypi
    [conda] torchfile                 0.1.0                    pypi_0    pypi
    [conda] torchvision               0.15.1a0+42759b1          pypi_0    pypi

zhengwy888 · Jun 02 '23 06:06