pytorch-enhance
pytorch-enhance copied to clipboard
Modified dataset to load images from a local directory
Hi @isaaccorley, I modified the dataset to load images from a local directory (in my case, a mounted Google Drive). It might be worth adding to the wiki, if not to the codebase.
class BaseDataset(torch.utils.data.Dataset):
    """Base Super Resolution Dataset Class

    Yields ``(lr, hr)`` pairs built from a list of image files. Both items of
    a pair come from the same source file; ``lr_transform`` and
    ``hr_transform`` (when set) are applied to them independently.
    ``get_lr_transforms`` / ``get_hr_transforms`` build bicubic-resize +
    ``ToTensor`` pipelines suitable for those attributes.

    This class reads ``self.file_names``, ``self.image_size`` and
    ``self.scale_factor`` but never assigns them. Presumably a subclass or
    the caller sets them (e.g. ``self.file_names = self.get_files(...)``);
    confirm against the concrete datasets.
    """

    # PIL mode that every loaded image is converted to (see load_img).
    color_space: str = "RGB"
    # Optional transforms applied in __getitem__. When left as None, the
    # corresponding item is returned as an untransformed PIL image.
    lr_transform: "T.Compose" = None
    hr_transform: "T.Compose" = None

    def get_lr_transforms(self):
        """Return the HR -> LR image transformations.

        The pipeline bicubically resizes the image to a
        ``(image_size // scale_factor, image_size // scale_factor)`` square
        and converts it to a tensor.
        """
        lr_size = self.image_size // self.scale_factor
        return Compose([
            Resize(size=(lr_size, lr_size), interpolation=Image.BICUBIC),
            ToTensor(),
        ])

    def get_hr_transforms(self):
        """Return the HR image transformations.

        The pipeline bicubically resizes the image to an
        ``(image_size, image_size)`` square and converts it to a tensor.
        """
        return Compose([
            Resize(
                size=(self.image_size, self.image_size),
                interpolation=Image.BICUBIC,
            ),
            ToTensor(),
        ])

    def get_files(
        self, data_dir: str, extensions: Tuple[str, ...] = (".jpg",)
    ) -> List[str]:
        """Return a sorted list of image files in a directory.

        Parameters
        ----------
        data_dir : str
            Path to the directory of images. A trailing path separator is
            optional.
        extensions : Tuple[str, ...], optional
            File extensions to match, each including the leading dot.
            Matching follows ``glob`` semantics. Defaults to ``(".jpg",)``.

        Returns
        -------
        List[str]
            Sorted paths of the matching files directly inside `data_dir`.
            Subdirectories are not searched. The list is empty when nothing
            matches.
        """
        import os  # local import keeps this block self-contained

        # os.path.join rather than `data_dir + "*.jpg"`: plain concatenation
        # turns "/data/imgs" into "/data/imgs*.jpg" when the trailing slash is
        # missing, silently matching the wrong files.
        # The set drops duplicates when several patterns hit the same file.
        # Sorting fixes the order, because glob returns files in arbitrary
        # order and the dataset is indexed by position.
        matches = {
            path
            for ext in extensions
            for path in glob.glob(os.path.join(data_dir, "*" + ext))
        }
        return sorted(matches)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(lr, hr)`` pair for the image at position `idx`.

        Parameters
        ----------
        idx : int
            Index into ``self.file_names``.

        Returns
        -------
        lr: torch.Tensor
            Low resolution version of the indexed image. This is a PIL image
            if ``lr_transform`` is not set.
        hr: torch.Tensor
            High resolution version of the indexed image. This is a PIL image
            if ``hr_transform`` is not set.
        """
        lr = self.load_img(self.file_names[idx])
        # Separate copy so each transform operates on its own image object.
        hr = lr.copy()
        if self.lr_transform:
            lr = self.lr_transform(lr)
        if self.hr_transform:
            hr = self.hr_transform(hr)
        return lr, hr

    def __len__(self) -> int:
        """Return the number of images in the dataset.

        Returns
        -------
        int
            Length of ``self.file_names``.
        """
        return len(self.file_names)

    def load_img(self, file_path: str) -> "Image.Image":
        """Load the image at `file_path` and convert it to ``color_space``.

        Parameters
        ----------
        file_path : str
            Path to the image file to be loaded.

        Returns
        -------
        PIL.Image.Image
            The loaded image, converted to ``self.color_space``.
        """
        # The context manager closes the file handle deterministically.
        # convert() returns a new, fully loaded image, so the result stays
        # valid after the source file is closed.
        with Image.open(file_path) as img:
            return img.convert(self.color_space)