pytorch-loss-functions
Fix PSNR error when averaging over all batches
Hi, thanks for the great work here!
However, I have noticed something weird when using the `PSNR()` class.
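The original implementation computes a single MSE over the whole batch of images and then calculates the PSNR from it:

$$\mathrm{PSNR} = 10\log_{10}\left(\frac{255^2}{\mathrm{MSE}_{avg}}\right)$$

where $\mathrm{MSE}_{avg}$ is the mean squared error taken over all of [B, C, H, W].

To calculate the mean PSNR, it should instead be:

$$\overline{\mathrm{PSNR}} = \frac{1}{B}\sum_{i=1}^{B} 10\log_{10}\left(\frac{255^2}{\mathrm{MSE}_i}\right)$$

where $\mathrm{MSE}_i$ is the MSE of the $i$-th image, taken over its [C, H, W].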
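A minimal PyTorch sketch of the per-image averaging, assuming float inputs shaped [B, C, H, W] with pixel values in [0, 255]. The function name `psnr_per_image_mean` and the `max_val` parameter are hypothetical and are not the repo's `PSNR()` API.

```python
import torch


def psnr_per_image_mean(pred: torch.Tensor, target: torch.Tensor, max_val: float = 255.0) -> torch.Tensor:
    """Mean PSNR over a batch, computing one MSE per image first.

    Hypothetical helper for illustration, not the repo's PSNR() class.
    Assumes pred/target are float tensors shaped [B, C, H, W].
    """
    # MSE_i: average squared error over each image's [C, H, W] -> shape [B]
    mse_per_image = torch.mean((pred - target) ** 2, dim=(1, 2, 3))
    # PSNR per image, then average over the batch (mean taken outside the log)
    psnr_per_image = 10 * torch.log10(max_val ** 2 / mse_per_image)
    return psnr_per_image.mean()


# Usage with random images in [0, 255]
pred = torch.rand(4, 3, 32, 32) * 255
target = torch.rand(4, 3, 32, 32) * 255
print(psnr_per_image_mean(pred, target))
```

An image identical to its target has $\mathrm{MSE}_i = 0$ and yields an infinite PSNR; clamping the MSE to a small epsilon is one common workaround, but that choice is an assumption here, not something this PR specifies.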
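Notice that the two approaches take the mean inside versus outside the log function, and these are not mathematically equivalent. Since $-\log$ is convex, Jensen's inequality implies that the mean of the per-image PSNRs is always at least the PSNR of the averaged MSE, with equality only when every image in the batch has the same MSE.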
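The gap can be illustrated with a worked example in plain Python, assuming a hypothetical batch of two images whose MSEs are 1 and 100 (made-up values chosen only to make the difference visible).

```python
import math

MAX_VAL = 255.0
mse_per_image = [1.0, 100.0]  # hypothetical per-image MSEs for a batch of B = 2


def psnr(mse: float, max_val: float = MAX_VAL) -> float:
    """PSNR in dB for a single MSE value."""
    return 10 * math.log10(max_val ** 2 / mse)


# Mean taken inside the log: average the MSEs first, then one PSNR
mse_avg = sum(mse_per_image) / len(mse_per_image)  # 50.5
psnr_of_mean_mse = psnr(mse_avg)

# Mean taken outside the log: one PSNR per image, then average
mean_of_psnrs = sum(psnr(m) for m in mse_per_image) / len(mse_per_image)

print(f"{psnr_of_mean_mse:.2f} dB vs {mean_of_psnrs:.2f} dB")  # prints "31.10 dB vs 38.13 dB"
```

The averaged-MSE value is dominated by the worse image, which is why it comes out roughly 7 dB lower than the mean of the per-image PSNRs.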
@styler00dollar
I didn't see this PR until I randomly checked this repo, thanks.