pytorch-enhance

Modified dataset to load from a local dir

Status: Open · opened by robmarkcole 3 years ago · 0 comments

Hi @isaaccorley, I modified the dataset to load images from a local directory (in my case, a mounted Google Drive). This might be worth adding to the wiki, if not to the codebase:

class BaseDataset(torch.utils.data.Dataset):
    """Base Super Resolution Dataset Class.

    Loads images from a local directory and returns ``(lr, hr)`` pairs that
    are both derived from the same source image.

    NOTE(review): ``image_size``, ``scale_factor`` and ``file_names`` are read
    by the methods below but are not defined in this class; presumably a
    subclass sets them (e.g. ``self.file_names = self.get_files(data_dir)``)
    — confirm against the concrete dataset.
    """
    # PIL mode every loaded image is converted to (see load_img).
    color_space: str = "RGB"
    # Optional transforms applied in __getitem__; None leaves the PIL image as is.
    lr_transform: T.Compose = None
    hr_transform: T.Compose = None

    def get_lr_transforms(self):
        """Return the HR -> LR image transformation pipeline.

        Bicubically resizes to a square of side
        ``image_size // scale_factor`` and converts the result to a tensor.
        """
        return Compose([
            Resize(size=(
                    self.image_size//self.scale_factor,
                    self.image_size//self.scale_factor
                ),
                interpolation=Image.BICUBIC
            ),
            ToTensor(),
        ])

    def get_hr_transforms(self):
        """Return the HR image transformation pipeline.

        Bicubically resizes to a square of side ``image_size`` and converts
        the result to a tensor.
        """
        return Compose([
            Resize((self.image_size, self.image_size), Image.BICUBIC),
            ToTensor(),
        ])

    def get_files(self, data_dir: str) -> List[str]:
        """Return a sorted list of ``.jpg`` files directly inside a directory.

        Parameters
        ----------
        data_dir : str
            Path to the directory of images. A trailing path separator is
            optional.

        Returns
        -------
        List[str]
            Sorted paths of the ``*.jpg`` files in ``data_dir``
            (non-recursive).
        """
        import os

        # Join with a separator: plain concatenation ("dir" + "*.jpg") would
        # glob "dir*.jpg" in the parent directory when the trailing slash is
        # omitted. Escape the directory part so characters such as "[" in the
        # path are matched literally instead of as glob wildcards.
        pattern = os.path.join(glob.escape(data_dir), "*.jpg")
        # glob returns matches in arbitrary, filesystem-dependent order; sort
        # so dataset indices are reproducible across runs and machines.
        return sorted(glob.glob(pattern))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a tuple of LR and HR torch tensors.

        Parameters
        ----------
        idx : int
            Index into ``self.file_names``.

        Returns
        -------
        lr: torch.Tensor
            Low Resolution transformed indexed image.
        hr: torch.Tensor
            High Resolution transformed indexed image.

        If a transform is None, the corresponding output is returned as the
        untransformed PIL image instead of a tensor.
        """
        lr = self.load_img(self.file_names[idx])
        # Both outputs start from the same source image; copy so the two
        # transforms never operate on a shared PIL object.
        hr = lr.copy()
        if self.lr_transform is not None:
            lr = self.lr_transform(lr)
        if self.hr_transform is not None:
            hr = self.hr_transform(hr)

        return lr, hr

    def __len__(self) -> int:
        """Return the number of images in the dataset.

        Returns
        -------
        int
            Length of the ``file_names`` list.
        """
        return len(self.file_names)

    def load_img(self, file_path: str) -> Image.Image:
        """Return a PIL Image of the image located at ``file_path``.

        Parameters
        ----------
        file_path : str
            Path to the image file to be loaded.

        Returns
        -------
        PIL.Image.Image
            Loaded image, converted to ``self.color_space``.
        """
        # Use the context manager so the underlying file handle is closed
        # deterministically; convert() returns a new, fully loaded image, so
        # it remains valid after the source file is closed.
        with Image.open(file_path) as img:
            return img.convert(self.color_space)

— robmarkcole, May 25 '21 09:05