SRGAN
Error when training on a custom dataset
Traceback (most recent call last):
File "train.py", line 220, in <module>
train()
File "train.py", line 152, in train
for step, (lr_patch, hr_patch) in enumerate(train_ds):
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 417, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 438, in _next_data
data = self._dataset_fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 347, in fetch
data = [self.dataset[id] for id in batch_indices]
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 347, in <listcomp>
data = [self.dataset[id] for id in batch_indices]
File "train.py", line 56, in __getitem__
lr_patch = self.lr_trans(hr_patch)
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/vision/transforms/transforms.py", line 274, in __call__
return resize(image, self.size, self.interpolation)
File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/vision/transforms/functional.py", line 202, in resize
output = cv2.resize(image, dsize=(size[1], size[0]), interpolation=_cv2_interp_from_str[method])
cv2.error: OpenCV(4.7.0) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
> - src data type = 17 is not supported
> - Expected Ptr<cv::UMat> for argument 'src'
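You need to make sure your image format is supported by `cv2.resize()`. Here, `src data type = 17` is NumPy's type number for the `object` dtype, which OpenCV cannot process.
Converting all images to NumPy `uint8` arrays solved the issue for me. If your images are floats in [0, 1], multiply them by 255 first, because `astype(np.uint8)` truncates fractional values.
In `__getitem__()`, add the conversion line like this: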
def __getitem__(self, index):
    """Return a normalized ``(lr_patch, hr_patch)`` pair for sample ``index``.

    The high-resolution image is converted to a ``uint8`` NumPy array before
    the transforms run. ``self.lr_trans`` ends in ``cv2.resize``, which
    rejects some dtypes (e.g. ``object`` arrays, reported by OpenCV as
    "src data type = 17").

    Args:
        index: Position of the sample in ``self.train_hr_imgs``.

    Returns:
        Tuple of ``nor(lr_patch)`` and ``nor(hr_patch)``.
    """
    img = np.asarray(self.train_hr_imgs[index])
    if img.dtype != np.uint8:
        # Clip before casting so out-of-range values saturate at 0/255
        # instead of wrapping around (a plain astype(np.uint8) maps 256 -> 0).
        # NOTE(review): assumes pixel values are on a 0-255 scale; float
        # images in [0, 1] would need to be multiplied by 255 first — confirm.
        img = np.clip(img.astype(np.float64), 0, 255).astype(np.uint8)
    hr_patch = self.hr_trans(img)
    # The low-res patch is derived from the already-cropped high-res patch.
    lr_patch = self.lr_trans(hr_patch)
    return nor(lr_patch), nor(hr_patch)