inferno icon indicating copy to clipboard operation
inferno copied to clipboard

Data transformations appear to apply to both train and val on cityscapes

Open pcicales opened this issue 4 years ago • 0 comments

  • inferno version: October 14 2020 (not installed)
  • Python version: 3.7.9
  • Operating System: Ubuntu 16.04

Description

I am working on adapting your cityscapes data loader. It appears that your transformations apply to both the train and validation set as documented below. This is a bug if that is the case.

What I Did

def make_transforms(image_shape, labels_as_onehot):
    # Make transforms
    image_transforms = Compose(PILImage2NumPyArray(),
                               NormalizeRange(),
                               RandomGammaCorrection(),
                               Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD))
    label_transforms = Compose(PILImage2NumPyArray(),
                               Project(projection=CITYSCAPES_CLASSES_TO_LABELS))
    joint_transforms = Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),
                                               preserve_aspect_ratio=True),
                               # Scale raw image back to the original shape
                               Scale(output_image_shape=image_shape,
                                     interpolation_order=3, apply_to=[0]),
                               # Scale segmentation back to the original shape
                               # (without interpolation)
                               Scale(output_image_shape=image_shape,
                                     interpolation_order=0, apply_to=[1]),
                               RandomFlip(allow_ud_flips=False),
                               # Cast raw image to float
                               Cast('float', apply_to=[0]))
    if labels_as_onehot:
        # Applying Label2OneHot on the full label image makes it unnecessarily expensive,
        # because we're throwing it away with RandomSizedCrop and Scale. Tests show that it's
        # ~1 sec faster per image.
        joint_transforms \
            .add(Label2OneHot(num_classes=len(CITYSCAPES_LABEL_WEIGHTS), dtype='bool',
                              apply_to=[1])) \
            .add(Cast('float', apply_to=[1]))
    else:
        # Cast label image to long
        joint_transforms.add(Cast('long', apply_to=[1]))
    # Batchify
    joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))
    # Return as kwargs
    return {'image_transform': image_transforms,
            'label_transform': label_transforms,
            'joint_transform': joint_transforms}


def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), labels_as_onehot=False,
                           include_coarse_dataset=False, read_from_zip_archive=True,
                           train_batch_size=1, validate_batch_size=1, num_workers=2):
    # Build datasets
    train_dataset = Cityscapes(root_directory, split='train',
                               read_from_zip_archive=read_from_zip_archive,
                               **make_transforms(image_shape, labels_as_onehot))
    if include_coarse_dataset:
        # Build coarse dataset
        coarse_dataset = Cityscapes(root_directory, split='train_extra',
                                    read_from_zip_archive=read_from_zip_archive,
                                    **make_transforms(image_shape, labels_as_onehot))
        # ... and concatenate with train_dataset
        train_dataset = Concatenate(coarse_dataset, train_dataset)

    validate_dataset = Cityscapes(root_directory, split='validate',
                                  read_from_zip_archive=read_from_zip_archive,
                                  **make_transforms(image_shape, labels_as_onehot))

    # Build loaders
    train_loader = data.DataLoader(train_dataset, batch_size=train_batch_size,
                                   shuffle=True, num_workers=num_workers, pin_memory=True)
    validate_loader = data.DataLoader(validate_dataset, batch_size=validate_batch_size,
                                      shuffle=True, num_workers=num_workers, pin_memory=True)
    return train_loader, validate_loader

pcicales avatar Oct 14 '20 19:10 pcicales