DALI icon indicating copy to clipboard operation
DALI copied to clipboard

Problem with multigpu parallel External Source operator

Open KsawerySmoczynski opened this issue 3 years ago • 9 comments

I was following the tutorials for the parallel External Source operator and chose to set py_start_method to "spawn". Unfortunately, while running it with PyTorch Lightning I only received errors similar to the ones attached here: 4gpus_error.log.

I'm not sure whether it can be attributed to DALI or pytorch-lightning. I've tried all of the pickling methods mentioned in the tutorials, but none resolved this issue. The example below uses DDP, but the error also occurred in a 1-GPU setup. Locally on my laptop it runs smoothly, but I cannot reproduce the multi-GPU setup there. In my normal setup I'm running it on a GCP NVIDIA 4xA100 machine.

Self-contained example reproducing the issue for 4-gpus based machine:

import random

from pathlib import Path
from typing import Union, Callable, List, Optional, Dict
import numpy as np
from PIL import Image
import cv2

from nvidia import dali
from nvidia.dali import pipeline_def, Pipeline
from nvidia.dali import fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali import pickling as dali_pickle
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
import torch
from torch import nn


@dali_pickle.pickle_by_value
def SegmentationCallback(
    imgs_list: List[str],
    depths_list: List[str],
    labels_list: List[str],
    batch_size: int,
    seed: int,
    shard_id: int = 0,
    num_shards: int = 1,
    **kwargs,
):
    """Build a per-sample source callable for a sharded segmentation dataset.

    The dataset is split into ``num_shards`` equally sized shards (the
    remainder is dropped) and the returned callable only serves samples of
    shard ``shard_id``. Every epoch uses a fresh permutation of the whole
    dataset, seeded with ``seed + epoch_idx`` so that all shards agree on it.

    Args:
        imgs_list: Paths of the encoded input images.
        depths_list: Paths of the depth maps, index-aligned with ``imgs_list``.
        labels_list: Paths of the encoded label images, index-aligned with
            ``imgs_list``.
        batch_size: Number of samples per batch; only full batches are served.
        seed: Base seed of the per-epoch permutation.
        shard_id: Index of the shard served by the returned callable.
        num_shards: Total number of shards.
        **kwargs: Ignored. Lets the caller pass a wider configuration dict.

    Returns:
        A callable taking a ``sample_info`` object and returning the tuple
        ``(encoded image bytes, depth map with a trailing channel axis,
        encoded label bytes)``. It raises ``StopIteration`` once
        ``shard_size // batch_size`` iterations have been served in an epoch.
    """
    dataset_size = len(imgs_list)
    shard_size = dataset_size // num_shards
    shard_offset = shard_size * shard_id
    n_iterations = shard_size // batch_size
    perm = None
    last_seen_epoch = None

    def callback(sample_info):
        nonlocal perm, last_seen_epoch
        if sample_info.iteration >= n_iterations:
            raise StopIteration
        if last_seen_epoch != sample_info.epoch_idx:
            # Recompute the permutation only when a new epoch starts.
            last_seen_epoch = sample_info.epoch_idx
            perm = np.random.default_rng(seed=seed + sample_info.epoch_idx).permutation(dataset_size)

        index = perm[sample_info.idx_in_epoch + shard_offset]

        # Context managers close the files deterministically; the callback runs
        # once per sample, so leaked handles would pile up quickly.
        with open(imgs_list[index], "rb") as img_file:
            img = np.frombuffer(img_file.read(), dtype=np.uint8)

        dsm = cv2.imread(depths_list[index], cv2.IMREAD_ANYDEPTH)
        if dsm is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise OSError(f"Could not read depth map: {depths_list[index]}")
        dsm = dsm[..., np.newaxis]

        with open(labels_list[index], "rb") as label_file:
            label = np.frombuffer(label_file.read(), dtype=np.uint8)

        return img, dsm, label

    return callback


@pipeline_def
def EnhancedPipeline(
    data_iterator: Callable,
    precision: int = 32,
):
    """Define the DALI graph that turns raw samples into (input, label) pairs.

    ``data_iterator`` is used as a per-sample, parallel external source with
    three outputs: encoded image, depth map and encoded label. The image and
    the label are decoded, the image is divided by 255, the depth map is
    appended to the image along axis 2, and both results are transposed with
    ``perm=[2, 0, 1]``.

    Args:
        data_iterator: Callable handed to ``fn.external_source`` as ``source``.
        precision: ``16`` selects ``FLOAT16`` for the image/depth output, any
            other value selects ``FLOAT``.

    Returns:
        The tuple ``(imgs, labels)`` of pipeline outputs.
    """
    # A pipeline without a device id runs everything on the CPU.
    on_gpu = Pipeline.current().device_id is not None
    if on_gpu:
        decoder_device = "mixed"
    else:
        decoder_device = "cpu"

    if precision == 16:
        float_type = dali.types.FLOAT16
    else:
        float_type = dali.types.FLOAT

    encoded_imgs, raw_depth, encoded_labels = fn.external_source(
        source=data_iterator,
        num_outputs=3,
        batch=False,
        parallel=True,
        dtype=[dali.types.UINT8, dali.types.FLOAT, dali.types.UINT8],
    )
    # The depth map is not decoded, so it has to be moved to the GPU explicitly.
    if on_gpu:
        raw_depth = raw_depth.gpu()

    decoded_imgs = fn.decoders.image(encoded_imgs, device=decoder_device)
    decoded_labels = fn.decoders.image(encoded_labels, device=decoder_device, output_type=dali.types.GRAY)

    decoded_imgs /= 255

    int_labels = fn.cast(decoded_labels, dtype=dali.types.INT64)
    float_imgs = fn.cast(decoded_imgs, dtype=float_type)
    float_depth = fn.cast(raw_depth, dtype=float_type)

    stacked = fn.cat(float_imgs, float_depth, axis=2)
    stacked, int_labels = fn.transpose(
        [stacked, int_labels],
        perm=[2, 0, 1],
    )

    return stacked, int_labels


class DALIGenericIteratorWrapper(DALIGenericIterator):
    """
    DALIGenericIterator variant that yields each batch as a plain list.

    The list holds the outputs of the first pipeline, ordered like
    ``output_map`` (e.g. ``[input, label]``), instead of a list of per-pipeline
    dictionaries.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: all arguments are forwarded to DALIGenericIterator.
        super().__init__(*args, **kwargs)

    def __next__(self):
        per_pipeline = super().__next__()
        # Only the first pipeline's dictionary is unpacked.
        return [per_pipeline[0][name] for name in self.output_map]


class Model(LightningModule):
    """Minimal segmentation model: one 1x1 convolution over 4 input channels.

    Args:
        num_classes: Number of output channels of the convolution, i.e. the
            number of segmentation classes.
        *args: Forwarded to ``LightningModule``.
        **kwargs: Forwarded to ``LightningModule``.
    """

    # ``num_classes`` was previously declared as ``num_classes=int``, which made
    # the builtin type the default value instead of annotating the parameter.
    def __init__(self, num_classes: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = nn.Conv2d(4, num_classes, 1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the per-pixel class scores for the input batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the loss of one training batch."""
        return self._step(batch)

    def validation_step(self, batch, batch_idx):
        """Return the loss of one validation batch."""
        return self._step(batch)

    def test_step(self, batch, batch_idx, dataloader_idx: Optional[int] = 0):
        """Return the loss of one test batch."""
        return self._step(batch, dataloader_idx=dataloader_idx)

    def _step(self, batch, dataloader_idx=None):
        """Compute the cross-entropy loss of an ``(inputs, masks)`` batch.

        ``dataloader_idx`` is accepted for interface symmetry and is not used.
        """
        inputs, masks = batch
        outputs = self.forward(inputs)
        # Axis 1 of the masks is dropped before they are used as targets.
        loss = self.criterion(outputs, masks.squeeze(1))
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class SegmentationDataModule(LightningDataModule):
    """Lightning data module serving segmentation batches through DALI.

    Every ``*_dataloader`` call builds a new DALI pipeline around a sharded
    ``SegmentationCallback`` and wraps it in ``DALIGenericIteratorWrapper``.

    Args:
        imgs_list: Paths of the encoded input images.
        depth_list: Paths of the depth maps, index-aligned with ``imgs_list``.
        labels_list: Paths of the encoded labels, index-aligned with
            ``imgs_list``.
        batch_size: Batch size of the DALI pipeline.
        num_workers: Used for both ``num_threads`` and ``py_num_workers``;
            ``0`` is replaced by ``1``.
        gpus: Falsy for a CPU pipeline (``device_id=None``), otherwise the
            pipeline starts with ``device_id=0``.
        accelerator: When truthy, ``device_id`` is replaced by the trainer's
            local rank at dataloader creation time.
        precision: Forwarded to ``EnhancedPipeline``.
        seed: Seed of the pipeline and of the sample permutation.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        imgs_list: List[str],
        depth_list: List[str],
        labels_list: List[str],
        batch_size: int,
        num_workers: int,
        gpus: Optional[Union[List[int], str, int]] = None,
        accelerator: Optional[str] = None,
        precision: int = 32,
        seed: int = 42,
        **kwargs,
    ):
        super().__init__()
        self.imgs_list = imgs_list
        self.depth_list = depth_list
        self.labels_list = labels_list

        self.pipeline_config = {
            "num_threads": num_workers if num_workers != 0 else 1,
            "device_id": None if not gpus else 0,
            "seed": seed,
            "batch_size": batch_size,
            "precision": precision,
            "py_num_workers": num_workers if num_workers != 0 else 1,
            "py_start_method": "spawn",
        }
        self.accelerator = accelerator

    def setup(self, stage: Optional[str] = None):
        """Register the same path lists for the train, val and test splits."""
        paths = {"imgs": self.imgs_list, "depths": self.depth_list, "labels": self.labels_list}
        self.train_dataset = {**paths}
        self.val_dataset = {**paths}
        self.test_dataset = {**paths}

    def train_dataloader(self):
        """Return a DALI iterator over the training split."""
        return self._get_iterator(
            self.train_dataset["imgs"],
            self.train_dataset["depths"],
            self.train_dataset["labels"],
            self.pipeline_config,
        )

    def val_dataloader(self):
        """Return a DALI iterator over the validation split."""
        return self._get_iterator(
            self.val_dataset["imgs"],
            self.val_dataset["depths"],
            self.val_dataset["labels"],
            self.pipeline_config,
        )

    def test_dataloader(self):
        """Return a DALI iterator over the test split."""
        return self._get_iterator(
            self.test_dataset["imgs"],
            self.test_dataset["depths"],
            self.test_dataset["labels"],
            self.pipeline_config,
        )

    def _get_iterator(self, imgs_list: List, depths_list: List, labels_list: List, pipeline_config: Dict):
        """Build a pipeline for this rank's shard and wrap it in an iterator.

        Raises:
            ValueError: If the three path lists differ in length, or if the
                shard of one rank holds less than one full batch.
        """
        # All three lists are indexed with the same index, so all must agree.
        if not (len(imgs_list) == len(depths_list) == len(labels_list)):
            raise ValueError(
                "imgs_list, depths_list and labels_list must have the same length, got "
                f"{len(imgs_list)}, {len(depths_list)} and {len(labels_list)}"
            )
        if (
            self.accelerator
        ):  # has to be assigned for DDP, is available only after pytorch-lightning trainer initialization
            pipeline_config = {
                **pipeline_config,
                "device_id": self.trainer.local_rank,
            }

        world_size = self.trainer.world_size

        # The sample source serves only full batches of its shard, so the
        # iterator size must be that exact sample count. With the plain shard
        # size the iterator would request one more batch from an exhausted
        # pipeline whenever the shard size is not a multiple of the batch size.
        batch_size = pipeline_config["batch_size"]
        shard_size = len(imgs_list) // world_size
        samples_per_epoch = (shard_size // batch_size) * batch_size
        if samples_per_epoch == 0:
            raise ValueError(
                f"A shard of {shard_size} samples cannot fill a single batch of {batch_size} samples"
            )

        data_callable = SegmentationCallback(
            imgs_list,
            depths_list,
            labels_list,
            shard_id=self.trainer.global_rank,
            num_shards=world_size,
            **pipeline_config,
        )
        pipe = EnhancedPipeline(data_callable, **pipeline_config)
        pipe.build()

        # auto_reset=True rewinds the iterator at the end of every epoch.
        # Nothing in this file calls reset(), so without it every epoch after
        # the first one would be empty.
        return DALIGenericIteratorWrapper(
            pipelines=pipe,
            output_map=["imgs", "labels"],
            size=samples_per_epoch,
            auto_reset=True,
        )


if __name__ == "__main__":
    # Seed every random number generator used by the script.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Synthetic sample: the bottom-right quadrant is the foreground.
    img = np.zeros((512, 512, 3), dtype=np.uint8)
    img[256:, 256:, 0] = 255

    # The depth map must be float32: a mode "F" image stores 32-bit floats, and
    # handing a float64 buffer to Image.fromarray(..., mode="F") reinterprets
    # the raw bytes and writes garbage. With float32 the mode is inferred.
    depth = np.zeros((512, 512), dtype=np.float32)
    depth[256:, 256:] = 100

    label = np.zeros((512, 512), dtype=np.uint8)
    label[256:, 256:] = 1

    data_root = Path("imgs")
    data_root.mkdir(parents=True, exist_ok=True)
    Image.fromarray(img).save(data_root / "img.png")
    Image.fromarray(depth).save(data_root / "depth.tif", compression="tiff_lzw")
    Image.fromarray(label).save(data_root / "label.png")

    # The dataset is the same sample repeated 100 times.
    imgs_list = [str(data_root / "img.png")] * 100
    depths_list = [str(data_root / "depth.tif")] * 100
    labels_list = [str(data_root / "label.png")] * 100

    batch_size = 4
    num_workers = 6
    gpus = 4
    accelerator = "ddp"
    precision = 16
    model = Model(2)
    datamodule = SegmentationDataModule(
        imgs_list, depths_list, labels_list, batch_size, num_workers, gpus, accelerator, precision
    )

    trainer_config = {"gpus": gpus, "max_epochs": 10, "accelerator": accelerator, "precision": precision}
    trainer = Trainer(**trainer_config)
    trainer.fit(model=model, datamodule=datamodule)

With this one solved, it would be the last one from my odyssey of DALI issues. Thanks a lot for your help!

KsawerySmoczynski avatar Jun 13 '22 13:06 KsawerySmoczynski

Hi @KsawerySmoczynski,

I'm not sure if you should recreate the DALI pipeline in every call to the *_dataloader method. Can you check this example, and create it only once in the setup method?

@stiepan - do you think that recreation of the pipeline can cause the below error:

    self._receive_chunk()
  File "/usr/local/lib/python3.8/dist-packages/nvidia/dali/_multiproc/pool.py", line 738, in _receive_chunk
    raise RuntimeError("Worker data receiving interrupted")
RuntimeError: Worker data receiving interrupted

JanuszL avatar Jun 13 '22 19:06 JanuszL

Thanks for your response @JanuszL I've implemented it according to your suggestion and example but the error unfortunately persists.

KsawerySmoczynski avatar Jun 14 '22 14:06 KsawerySmoczynski

Hi @KsawerySmoczynski.

One thing that crossed my mind is the size of /dev/shm. External source workers rely heavily on shared memory, while containers tend to limit its size. The error message you see could be due to a worker process exiting when it cannot allocate enough shm to pass the samples. I am not sure if that is applicable in your case, but here's a Stack Overflow question on how to increase the /dev/shm size in a GCP VM: https://stackoverflow.com/questions/66456581/gcc-vm-machine-allocate-more-dev-shm-memory.

stiepan avatar Jun 14 '22 15:06 stiepan

Bingo, that was probably it, thanks a lot @stiepan! Now the problem is that after the first epoch the execution hangs, and then it proceeds with one batch per epoch, skipping the validation phases. The auto_reset parameter of DALIGenericIterator is set to True. The machine has 48 cores; I've set the num_threads parameter to 3 and py_num_workers to 4 (per GPU). Here is the log.

As you can see, there is an almost 10-minute break between the init of the 1st train epoch and the batches being processed (lines 98 and 99). Then the whole log is written at once, as I observed while inspecting it during the training. The same situation occurs when running it on a pod with 1 GPU. Do you know what might be the issue?

KsawerySmoczynski avatar Jun 14 '22 18:06 KsawerySmoczynski

Hi @KsawerySmoczynski,

As you can see there is a almost 10 minutes break between init of 1-st train epoch and batches being processed (line 98 and 99). Then the whole log is being written at once as I've inspected it during the training.

I guess this might be some kind of python printing buffering. Can you try out suggestions from this StackOverflow thread?

Running your code sample, I also see that you may experience the issue raised in https://github.com/PyTorchLightning/pytorch-lightning/issues/12956, which we want to fix with https://github.com/NVIDIA/DALI/pull/3923.

JanuszL avatar Jun 14 '22 23:06 JanuszL

Thanks for suggestion @JanuszL

Before adding external source operator (using DALI readers) there was no problem with buffering. Using flag python -u (...) doesn't resolve this issue. In my setup I'm using MLFlow logger utility and the metrics logging also lags.

I've set auto_reset=False on DALIGenericIterator, added a manual reset of the iterator at each call to the {train,validation,test}_dataloader methods, and used reload_dataloaders_every_n_epochs=1 in the pytorch-lightning trainer, but with no luck. I'm using pytorch-lightning 1.5.9. The only thing that changed was the printed warnings that the DALI iterator doesn't support resetting while the epoch is not finished.

I'm attaching log here

KsawerySmoczynski avatar Jun 15 '22 11:06 KsawerySmoczynski

Hi @KsawerySmoczynski,

I have tested PYTHONUNBUFFERED=1 and it helped me in the case I had. Maybe here there is a different cause of the problem. I run your repro (with auto_reset=True) and I don't see any logging lag. Does it reproduce on your side with this simple code as well?

JanuszL avatar Jun 15 '22 11:06 JanuszL

Thanks for your response @JanuszL

With simple code, when I multiplied the list with images paths by 100k the lag also occurred but was significantly shorter than in the case of real-life scenario. All in all, lag does not bother me, I thought that it might be a hint to what is happening.

Anyway, the simple example also runs smoothly on the pod. In my real-life scenario I've limited the number of images that the iterator will load in total to 100 per GPU, and it again runs without a problem. But with all image paths passed to the callback, the situation from the previous logs occurs. I've also run a training with torch dataloaders and it runs smoothly, so none of the images is corrupted. Can loading images from different memory regions contribute at all to this situation, compared to the simple code where we repeatedly load the same image? Also, in my scenario the images are read directly from the NAS without a copy to the host machine.

limited_data.log all_data_error.log

KsawerySmoczynski avatar Jun 17 '22 07:06 KsawerySmoczynski

@stiepan - it seems that spawning ES processes with callbacks having a huge (regarding the size in MB) argument list is a slow process. This is probably caused by the pickler, which needs to process MBs of data. Maybe you can try passing the file name list as a file and reading it in the callback, or use shared memory to pass the list itself:

...
from multiprocessing import shared_memory
# Store each path list in a shared-memory block instead of a regular list.
# NOTE(review): presumably the creating process has to keep these objects alive
# and release them (shm.close() / shm.unlink()) when training ends - confirm.
imgs_list = shared_memory.ShareableList([str(data_root / "img.png")] * 10000000)
depths_list = shared_memory.ShareableList([str(data_root / "depth.tif")] * 10000000)
labels_list = shared_memory.ShareableList([str(data_root / "label.png")] * 10000000)
...
# Hand over only the names of the shared-memory blocks, not the lists themselves.
data_callable = SegmentationCallback(
            imgs_list.shm.name,
            depths_list.shm.name,
            labels_list.shm.name,
            shard_id=self.trainer.global_rank,
            num_shards=self.trainer.world_size,
            **pipeline_config,
        )
...
# Variant of the callback factory whose first three arguments are the names of
# shared-memory blocks rather than lists of paths.
@dali_pickle.pickle_by_value
def SegmentationCallback(
    imgs_list,
    depths_list,
    labels_list,
    batch_size: int,
    seed: int,
    shard_id: int = 0,
    num_shards: int = 1,
    **kwargs,
):
    from multiprocessing import shared_memory
    # Attach to the existing blocks by name; each name is rebound to the
    # attached ShareableList.
    imgs_list = shared_memory.ShareableList(name=imgs_list)
    depths_list = shared_memory.ShareableList(name=depths_list)
    labels_list = shared_memory.ShareableList(name=labels_list)
...

JanuszL avatar Jun 17 '22 15:06 JanuszL