
How to use ExternalSourcePipeline for distributed training

pawopawo opened this issue on Jul 22, 2021 · 4 comments

from random import shuffle

import numpy as np
import torch

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.plugin.pytorch import LastBatchPolicy

class ExternalInputIterator(object):
    def __init__(self, batch_size, device_id, num_gpus):
        self.images_dir = ""
        self.batch_size = batch_size
        with open(self.images_dir + "ILSVRC2012_img_val.txt", 'r') as f:
            self.files = [line.rstrip() for line in f if line.strip() != '']
        # whole data set size
        self.data_set_len = len(self.files)

        # based on the device_id and total number of GPUs - world size
        # get proper shard
        self.files = self.files[self.data_set_len * device_id // num_gpus:
                                self.data_set_len * (device_id + 1) // num_gpus]
        self.n = len(self.files)


    def __iter__(self):
        self.i = 0
        shuffle(self.files)
        return self

    def __next__(self):
        batch = []
        labels = []
        jpeg_filenames = []
        if self.i >= self.n:
            self.__iter__()      # reshuffle and rewind for the next epoch
            raise StopIteration  # signal the end of the current epoch

        for _ in range(self.batch_size):
            jpeg_filename, label = self.files[self.i % self.n].split(' ')

            batch.append(np.fromfile(self.images_dir + jpeg_filename, dtype=np.uint8))  # we can use numpy
            labels.append(torch.tensor([int(label)], dtype=torch.int64))  # or PyTorch's native tensors; int64 avoids wrap-around for labels > 255

            # encode the filename as a uint8 array so it can travel through the pipeline
            jpeg_filenames.append(np.frombuffer(jpeg_filename.encode(), dtype=np.uint8))


            self.i += 1

        return (batch, labels, jpeg_filenames)

    def __len__(self):
        return self.data_set_len

    next = __next__


def ExternalSourcePipeline(batch_size, num_threads, device_id, external_data):
    pipe = Pipeline(batch_size, num_threads, device_id)
    with pipe:
        jpegs, labels, jpeg_filenames = fn.external_source(source=external_data, num_outputs=3)

        images = fn.decoders.image(jpegs,
                                   device="mixed",
                                   output_type=types.RGB)
        images = fn.resize(images,
                           device="gpu",
                           size=224,
                           mode="not_smaller",
                           interp_type=types.INTERP_TRIANGULAR)
        mirror = False
        crop = 224
        images = fn.crop_mirror_normalize(images.gpu(),
                                          dtype=types.FLOAT,
                                          output_layout="CHW",
                                          crop=(crop, crop),
                                          mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                          std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                          mirror=mirror)

        # images = fn.cast(images, dtype=types.UINT8)
        # labels = labels.gpu()
        # self.cast = ops.Cast(device="gpu", dtype=types.UINT8)

        pipe.set_outputs(images, labels, jpeg_filenames)
    return pipe


batch_size = 16  # example value
epochs = 10      # example value

eii = ExternalInputIterator(batch_size, 0, 1)  # single GPU: device_id=0, num_gpus=1
pipe = ExternalSourcePipeline(batch_size=batch_size,
                              num_threads=2,
                              device_id=0,
                              external_data=eii)
pii = DALIGenericIterator(pipe,
                          ['data', 'label', 'path'],
                          last_batch_policy=LastBatchPolicy.PARTIAL)

for e in range(epochs):
    for i, data in enumerate(pii):
        print("epoch: {}, iter {}, real batch size: {}".format(e, i, len(data[0]["data"])))

        # print (len(data[0]["path"]), len(data[0]["label"]))
        # print ("".join([chr(item) for item in data[0]["path"][0]]), data[0]["label"])

    pii.reset()

pawopawo · Jul 22 '21 15:07

The fn.external_source function does not have a shard_id argument.

pawopawo · Jul 22 '21 15:07

@pawopawo Yes, it's not there. It's a limitation, but you can always make shard_id a member of the source. If you use an iterator, like in this example, you need multiple instances of it anyway (one per pipeline), so you can embed the shard_id in the iterator and make each instance traverse only one shard. If you use a callable as the source argument, you can make it stateful, for example by using a lambda - again, making it shard-aware.

mzient · Jul 22 '21 16:07
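For illustration, here is a minimal sketch of the callable approach described in the comment above. The helper make_shard_reader and its arguments are hypothetical names, not DALI API; the only DALI-specific behaviour relied on is that a callable passed to fn.external_source (with the default batch=True) is called with the iteration index.

import numpy as np

def make_shard_reader(files, batch_size, shard_id, world_size):
    # Keep only this shard's slice of the "filename label" lines.
    shard = files[len(files) * shard_id // world_size:
                  len(files) * (shard_id + 1) // world_size]

    def read_batch(iteration):
        # DALI passes the iteration index to a callable source, so every call
        # builds one full batch from this shard, wrapping around at the end.
        batch, labels = [], []
        for k in range(batch_size):
            name, label = shard[(iteration * batch_size + k) % len(shard)].split(' ')
            batch.append(np.fromfile(name, dtype=np.uint8))
            labels.append(np.array([int(label)], dtype=np.int64))
        return batch, labels

    return read_batch

# inside the pipeline definition, one instance per pipeline/shard:
# jpegs, labels = fn.external_source(
#     source=make_shard_reader(files, batch_size, shard_id, world_size),
#     num_outputs=2)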

Sorry, I don't quite understand. Could you help implement this part of the code? Thank you very much~

pawopawo · Jul 23 '21 07:07

Hi @pawopawo,

What @mzient wanted to say is to add shard_id as an ExternalInputIterator constructor argument. Also, you should not confuse shard_id with device_id. In the case of single-node training they happen to be the same, but with a multi-node setup that is no longer true: device_id is the GPU id within a single node, while shard_id identifies the part of the data set each GPU (globally) should work on. So something like this should do (I just renamed the variables; the code looks good, but I haven't tested it):

from random import shuffle

import numpy as np
import torch

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.plugin.pytorch import LastBatchPolicy

class ExternalInputIterator(object):
    def __init__(self, batch_size, shard_id, world_size):
        self.images_dir = ""
        self.batch_size = batch_size
        with open(self.images_dir + "ILSVRC2012_img_val.txt", 'r') as f:
            self.files = [line.rstrip() for line in f if line.strip() != '']
        # whole data set size
        self.data_set_len = len(self.files)

        # based on the shard_id and total number of GPUs - world size
        # get proper shard
        self.files = self.files[self.data_set_len * shard_id // world_size:
                                self.data_set_len * (shard_id + 1) // world_size]
        self.n = len(self.files)


    def __iter__(self):
        self.i = 0
        shuffle(self.files)
        return self

    def __next__(self):
        batch = []
        labels = []
        jpeg_filenames = []
        if self.i >= self.n:
            self.__iter__()      # reshuffle and rewind for the next epoch
            raise StopIteration  # signal the end of the current epoch

        for _ in range(self.batch_size):
            jpeg_filename, label = self.files[self.i % self.n].split(' ')

            batch.append(np.fromfile(self.images_dir + jpeg_filename, dtype=np.uint8))  # we can use numpy
            labels.append(torch.tensor([int(label)], dtype=torch.int64))  # or PyTorch's native tensors; int64 avoids wrap-around for labels > 255

            # encode the filename as a uint8 array so it can travel through the pipeline
            jpeg_filenames.append(np.frombuffer(jpeg_filename.encode(), dtype=np.uint8))


            self.i += 1

        return (batch, labels, jpeg_filenames)

    def __len__(self):
        return self.data_set_len

    next = __next__


def ExternalSourcePipeline(batch_size, num_threads, device_id, external_data):
    pipe = Pipeline(batch_size, num_threads, device_id)
    with pipe:
        jpegs, labels, jpeg_filenames = fn.external_source(source=external_data, num_outputs=3)

        images = fn.decoders.image(jpegs,
                                   device="mixed",
                                   output_type=types.RGB)
        images = fn.resize(images,
                           device="gpu",
                           size=224,
                           mode="not_smaller",
                           interp_type=types.INTERP_TRIANGULAR)
        mirror = False
        crop = 224
        images = fn.crop_mirror_normalize(images.gpu(),
                                          dtype=types.FLOAT,
                                          output_layout="CHW",
                                          crop=(crop, crop),
                                          mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                          std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                          mirror=mirror)

        # images = fn.cast(images, dtype=types.UINT8)
        # labels = labels.gpu()
        # self.cast = ops.Cast(device="gpu", dtype=types.UINT8)

        pipe.set_outputs(images, labels, jpeg_filenames)
    return pipe


batch_size = 16  # example value
epochs = 10      # example value

eii = ExternalInputIterator(batch_size, shard_id=0, world_size=1)
pipe = ExternalSourcePipeline(batch_size=batch_size,
                              num_threads=2,
                              device_id=0,
                              external_data=eii)
pii = DALIGenericIterator(pipe,
                          ['data', 'label', 'path'],
                          last_batch_policy=LastBatchPolicy.PARTIAL)

for e in range(epochs):
    for i, data in enumerate(pii):
        print("epoch: {}, iter {}, real batch size: {}".format(e, i, len(data[0]["data"])))

        # print (len(data[0]["path"]), len(data[0]["label"]))
        # print ("".join([chr(item) for item in data[0]["path"][0]]), data[0]["label"])

    pii.reset()

JanuszL · Jul 26 '21 07:07
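The snippet above keeps shard_id=0 and world_size=1 for a single GPU. For a multi-GPU / multi-node run, here is a hedged sketch of how the three values could be derived; the environment variable names are the standard ones exported by torchrun (and torch.distributed.launch), while the rest of the wiring is an assumption, not part of the original answer.

import os

# torchrun exports these for every process:
#   RANK       - global rank across all nodes (0 .. WORLD_SIZE - 1)
#   WORLD_SIZE - total number of processes across all nodes
#   LOCAL_RANK - rank of the process within its own node
shard_id = int(os.environ.get("RANK", 0))           # which shard of the data set to read
world_size = int(os.environ.get("WORLD_SIZE", 1))   # total number of shards
device_id = int(os.environ.get("LOCAL_RANK", 0))    # which GPU on this node to use

eii = ExternalInputIterator(batch_size, shard_id=shard_id, world_size=world_size)
pipe = ExternalSourcePipeline(batch_size=batch_size,
                              num_threads=2,
                              device_id=device_id,
                              external_data=eii)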