torch-autograd
Question about implementing a pairwise L2 distances function
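Hi, I'm trying to implement a pairwise L2-distance module. The code below prints:
```
function: 0x41cd5570
3.999683343136
```
The second number is the maximum error that `nn.Jacobian.testJacobian` reports between the analytical and numerical Jacobians. A value of about 4 means the gradient check fails. What am I missing?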
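```lua
require 'nn'
autograd = require 'autograd'

-- One-argument transpose so the code below reads naturally
torch.transpose = function(x) return x:t() end

-- Squared pairwise L2 distances between the rows of `embeddings` (N x D):
-- D[i][j] = ||x_i||^2 - 2 * <x_i, x_j> + ||x_j||^2
function pdist(embeddings)
   local pdist2 = embeddings * torch.transpose(embeddings)  -- Gram matrix X * X^T
   -- Squared row norms (diagonal of the Gram matrix); row i of `norm` holds ||x_i||^2
   local norm = torch.diag(pdist2):view(pdist2:size(1), 1):expandAs(pdist2)
   return norm - 2.0 * pdist2 + torch.transpose(norm)
end

-- Register a gradient for torch.diag so autograd can differentiate through it
(require 'autograd.overload').module("torch", torch, function(module)
   module.gradient("diag", {function(g, ans, x)
      return torch.diag(g)
   end})
end)

m = autograd.nn.AutoModule('AutoPairwiseL2')(pdist)
print(nn.Jacobian.testJacobian(m, torch.rand(50, 128)))
```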
P.S. I had to overload `torch.transpose` because Torch's built-in version takes two dimension arguments and has no one-argument form.
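For reference, here is what `pdist` computes and the gradient a correct implementation should produce, derived by hand (this derivation is my own sketch, not taken from autograd). Write $x_i$ for the rows of the input $X$, $n$ for the vector of squared row norms, and $G = \partial L / \partial D$ for the upstream gradient. As with `torch.diag`, $\operatorname{diag}$ of a matrix extracts its diagonal and $\operatorname{diag}$ of a vector builds a diagonal matrix:

```latex
\begin{aligned}
D_{ij} &= \lVert x_i \rVert^2 - 2\, x_i^\top x_j + \lVert x_j \rVert^2,
\qquad D = n\mathbf{1}^\top - 2XX^\top + \mathbf{1}n^\top,
\qquad n = \operatorname{diag}(XX^\top) \\
\frac{\partial L}{\partial X} &= 2\,\operatorname{diag}\!\left(G\mathbf{1} + G^\top\mathbf{1}\right) X
  - 2\left(G + G^\top\right) X
\end{aligned}
```

Evaluating this expression for a small $X$ and comparing it with the module's gradient is one way to check the result independently of `nn.Jacobian`.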
I chatted with @szagoruyko, and he suggested inserting `autograd.optimize(true)` after the `m = ...` line. That fixes the issue, but the root cause is still unclear.
If the `diag` gradient looks correct, I can prepare a PR for it (along with a test).
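A quick check of the proposed `diag` gradient (my own derivation, offered as a sketch). `torch.diag` either extracts the diagonal of a matrix or builds a diagonal matrix from a vector, and the backward pass of each case is the other case, which is what returning `torch.diag(g)` does:

```latex
\begin{aligned}
y = \operatorname{diag}(X),\; y_i = X_{ii}
  &\;\Rightarrow\; \frac{\partial L}{\partial X_{ij}} = \delta_{ij}\, g_i,
  \quad\text{i.e. } \frac{\partial L}{\partial X} = \operatorname{diag}(g) \\
Y = \operatorname{diag}(x),\; Y_{ij} = \delta_{ij}\, x_i
  &\;\Rightarrow\; \frac{\partial L}{\partial x_i} = G_{ii},
  \quad\text{i.e. } \frac{\partial L}{\partial x} = \operatorname{diag}(G)
\end{aligned}
```

The first case assumes a square input: for a non-square $X$, $\operatorname{diag}(g)$ is square and does not match the shape of $X$. The gradient above also ignores the optional `k` (off-diagonal offset) argument of `torch.diag`. Neither limitation affects `pdist`, where the input to `torch.diag` is the square Gram matrix.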