contextual_loss_pytorch
Bug in functions compute_l1_distance and compute_l2_distance
```python
dist = dist.sum(dim=1).abs()
```

at line 162 in contextual_loss/functional.py is not an L1 distance; the correct version is:

```python
dist = dist.abs().sum(dim=1)
```
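To see why the order matters: summing the signed per-channel differences first lets positive and negative terms cancel before the absolute value is taken. A minimal illustration:

```python
import torch

# Per-channel differences for one pair of feature vectors, shape (1, 2).
diff = torch.tensor([[3.0, -3.0]])

# Summing signed differences first lets the +3 and -3 cancel:
wrong = diff.sum(dim=1).abs()   # |3 + (-3)| = 0
# Taking absolute values first gives the true L1 distance:
right = diff.abs().sum(dim=1)   # |3| + |-3| = 6

print(wrong.item(), right.item())  # 0.0 6.0
```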
In line 162 in contextual_loss/functional.py you should also transpose matrix A:

```python
dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(0, 1)
```
I believe there is also a bug in the compute_l2_distance function. One correct implementation can be:
```python
def compute_l2_distance(x, y):
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1, keepdim=True)
    y_s = torch.sum(y_vec ** 2, dim=1, keepdim=True)
    A = y_vec.transpose(1, 2) @ x_vec
    # print(x.shape, y_s.shape, A.shape, x_s.shape)
    dist = y_s - 2 * A + x_s.transpose(1, 2)
    dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
    dist = dist.clamp(min=0.)
    return dist
```
Feel free to point out any potential bugs!
```python
dist = y_s - 2 * A + x_s.transpose(0, 1)  # change here (0, 1)
```

See https://github.com/S-aiueo32/contextual_loss_pytorch/issues/6:

RuntimeError: The size of tensor a (4096) must match the size of tensor b (2) at non-singleton dimension 1

Good!!! You are right. If I do not correct this, it goes wrong when batch_size > 1.
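The shape mismatch is easy to reproduce: `x_s` has shape `(N, 1, H*W)`, so `transpose(0, 1)` swaps the batch and singleton dimensions instead of producing a column vector. A small sketch (using 16 spatial positions instead of the 4096 in the error above):

```python
import torch

N, P = 2, 16                 # batch size > 1, P = H * W spatial positions
x_s = torch.zeros(N, 1, P)   # per-position squared norms, shape (N, 1, P)
A = torch.zeros(N, P, P)     # pairwise inner products, shape (N, P, P)

# transpose(0, 1) yields (1, N, P), which cannot broadcast against
# (N, P, P) at dimension 1 whenever N != P -- the RuntimeError above.
try:
    A + x_s.transpose(0, 1)
except RuntimeError as e:
    print(e)

# transpose(1, 2) yields (N, P, 1), which broadcasts as intended.
print((A + x_s.transpose(1, 2)).shape)  # torch.Size([2, 16, 16])
```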
As mentioned by @DmitryBabichev, A should be transposed. Namely:

```python
dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(1, 2)
```
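For reference, here is a sketch of the function with both fixes from this thread applied (the batched `x_s.transpose(1, 2)` and the transposed `A`), checked against `torch.cdist` with batch_size > 1; this is an illustration, not the repository's code:

```python
import torch

def compute_l2_distance(x, y):
    # Pairwise squared L2 distances between the spatial positions of x
    # and y, with both fixes: x_s.transpose(1, 2) and A.transpose(1, 2).
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1, keepdim=True)  # (N, 1, H*W)
    y_s = torch.sum(y_vec ** 2, dim=1, keepdim=True)  # (N, 1, H*W)
    A = y_vec.transpose(1, 2) @ x_vec                 # (N, H*W, H*W)
    dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(1, 2)
    dist = dist.transpose(1, 2).reshape(N, H * W, H * W)
    return dist.clamp(min=0.)

# Sanity check against torch.cdist with batch_size > 1.
torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)
y = torch.randn(2, 3, 4, 4)

dist = compute_l2_distance(x, y)
# cdist expects (N, P, C) point sets; compare squared distances.
ref = torch.cdist(y.view(2, 3, -1).transpose(1, 2),
                  x.view(2, 3, -1).transpose(1, 2)) ** 2
print(torch.allclose(dist, ref, atol=1e-4))  # True
```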