srntt-pytorch
TypeError: Input max_val type is not a float. Got <class 'str'>.
I trained on my own small dataset. When I ran train.py, I got this error. I later found that class SSIM passes an argument of the wrong type to kornia.losses.ssim().
Source code:

```python
import kornia
import torch.nn as nn


class SSIM(nn.Module):
    def __init__(self, window_size=11):
        super(SSIM, self).__init__()
        self.window_size = window_size

    def forward(self, x, y):
        if x.shape[1] == 3:
            x = kornia.color.rgb_to_grayscale(x)
        if y.shape[1] == 3:
            y = kornia.color.rgb_to_grayscale(y)
        # 'mean' is passed in the position where kornia expects max_val (a float)
        return 1 - kornia.losses.ssim(x, y, self.window_size, 'mean')
```
kornia requires max_val to be a float, but here a str ('mean') ends up in that position. Am I doing something wrong, or is there a problem with the source code?
My environment: kornia 0.5.8, Python 3.7.0, PyTorch 1.7.0.
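A minimal pure-Python sketch of why the error appears. `ssim_stub`, `call_like_srntt` and `call_fixed` are hypothetical names. The stub assumes, based only on the error message above, that `kornia.losses.ssim` in kornia 0.5.8 has the parameter order `(img1, img2, window_size, max_val=1.0, eps=1e-12)` and type-checks `max_val`. It does not compute SSIM; it only shows how the fourth positional argument `'mean'` binds to `max_val`.

```python
def ssim_stub(img1, img2, window_size, max_val=1.0, eps=1e-12):
    """Hypothetical stand-in for kornia.losses.ssim (assumed 0.5.x order).

    Reproduces only the max_val type check and returns the bound arguments
    instead of computing SSIM.
    """
    if not isinstance(max_val, float):
        raise TypeError(f"Input max_val type is not a float. Got {type(max_val)}.")
    return {"window_size": window_size, "max_val": max_val, "eps": eps}


def call_like_srntt(window_size=11):
    # Same call shape as SSIM.forward: ssim(x, y, self.window_size, 'mean').
    # The 4th positional argument binds to max_val, not to a reduction mode.
    return ssim_stub("x", "y", window_size, "mean")


def call_fixed(window_size=11):
    # Drop the positional 'mean'; pass max_val explicitly by keyword.
    return ssim_stub("x", "y", window_size, max_val=1.0)


try:
    call_like_srntt()
except TypeError as err:
    print(err)  # Input max_val type is not a float. Got <class 'str'>.

print(call_fixed())
```

If that assumption about the signature holds, the 0.5.x `ssim` returns a per-pixel SSIM map rather than a reduced value, so one possible patch to `SSIM.forward` is `return 1 - kornia.losses.ssim(x, y, self.window_size).mean()`; another is installing the older kornia release the repository was written against. Both are untested suggestions, not a maintainer fix.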
I ran into the same problem. Could you please tell me how to solve it?