torch_sparse_solve
Fails to pass `torch.autograd.gradcheck`
Thank you very much for providing such a good tool!
My problem is that when the input A is a 'real' sparse matrix, not a sparse matrix converted from a dense matrix, the `torch.autograd.gradcheck()` function throws an exception. The Python program I use is:
```python
import scipy.sparse
import scipy.sparse.linalg
import torch


class SparseSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        '''
        A is a torch COO sparse matrix
        b is a dense tensor
        '''
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")
        # Convert A to a scipy CSR matrix
        A = A.coalesce()
        A_idx = A.indices().to('cpu').numpy()
        A_val = A.values().to('cpu').numpy()
        sci_A = scipy.sparse.coo_matrix((A_val, (A_idx[0, :], A_idx[1, :])), shape=A.shape)
        sci_A = sci_A.tocsr()
        np_b = b.detach().cpu().numpy()
        # Solve the sparse system
        if np_b.ndim == 1:
            np_x = scipy.sparse.linalg.spsolve(sci_A, np_b)
        else:
            factorisedsolver = scipy.sparse.linalg.factorized(sci_A)
            np_x = factorisedsolver(np_b)
        x = torch.as_tensor(np_x)
        # Not sure if the following is needed / helpful
        if A.requires_grad or b.requires_grad:
            x.requires_grad = True
        # Save context for backward pass
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        # Recover context
        A, b, x = ctx.saved_tensors
        # Gradient with respect to b: gradb = A^{-T} grad
        gradb = SparseSolve.apply(A.t(), grad)
        # Gradient with respect to A, only at the stored indices (i, j):
        # gradA[i, j] = -gradb[i] * x[j]
        gradAidx = A.indices()
        mgradbselect = -gradb.index_select(0, gradAidx[0, :])
        xselect = x.index_select(0, gradAidx[1, :])
        mgbx = mgradbselect * xselect
        if x.dim() == 1:
            gradAvals = mgbx
        else:
            gradAvals = torch.sum(mgbx, dim=1)
        gradA = torch.sparse_coo_tensor(gradAidx, gradAvals, A.shape)
        return gradA, gradb


sparsesolve = SparseSolve.apply

row_vec = torch.tensor([0, 0, 1, 2])
col_vec = torch.tensor([0, 2, 1, 2])
val_vec = torch.tensor([3.0, 4.0, 5.0, 6.0], dtype=torch.float64)
A = torch.sparse_coo_tensor(torch.stack((row_vec, col_vec), 0), val_vec, (3, 3))
b = torch.ones(3, dtype=torch.float64, requires_grad=False)
A.requires_grad = True
b.requires_grad = True

res = torch.autograd.gradcheck(sparsesolve, [A, b], raise_exception=True)
print(res)
```
This program is based on the one from Differentiable sparse linear solver with cupy backend - “unsupported tensor layout: Sparse” in gradcheck, whose author @tvercaut wrote it based on your blog post and code. I modified the program and limited it to running only on the CPU.
The output is
```
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-3.7037e-02,  0.0000e+00, -1.3878e-11],
        [-6.6667e-02,  0.0000e+00,  0.0000e+00],
        [-5.5556e-02,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.2222e-02,  0.0000e+00],
        [ 0.0000e+00, -4.0000e-02,  0.0000e+00],
        [ 0.0000e+00, -3.3333e-02,  0.0000e+00],
        [ 2.4691e-02,  0.0000e+00, -1.8519e-02],
        [ 4.4444e-02,  0.0000e+00, -3.3333e-02],
        [ 3.7037e-02,  0.0000e+00, -2.7778e-02]], dtype=torch.float64)
analytical:tensor([[-0.0370,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0556,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0400,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0370,  0.0000, -0.0278]], dtype=torch.float64)
```
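To make the mismatch easier to read (my own annotation, assuming `gradcheck` flattens A in row-major order): row $3i+j$ of the Jacobian holds $\partial x/\partial A_{ij}$ and each column corresponds to one entry of $x$. The analytical Jacobian is nonzero only in rows 0, 2, 4 and 8, i.e. exactly at the four stored entries $A_{00}$, $A_{02}$, $A_{11}$ and $A_{22}$, while the numerical Jacobian is also nonzero in rows such as 1 ($A_{01}$) and 6 ($A_{20}$), which are structural zeros of A.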
I believe your program and @tvercaut's pass the gradient check only because the sparse matrix A used there is effectively dense: every entry is stored, so autograd computes a gradient for each element.
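To illustrate what I mean by a 'real' sparse matrix, here is a small comparison (a hypothetical example of mine, not taken from your code): converting a random dense matrix with `.to_sparse()` stores every entry, whereas the A in my program has structural zeros.

```python
import torch

# A sparse tensor obtained by converting a random dense matrix: torch.randn
# produces (almost surely) no exact zeros, so all 9 entries are stored and
# there are no structural zeros.
A_from_dense = torch.randn(3, 3, dtype=torch.float64).to_sparse()
print(A_from_dense.coalesce().values().numel())   # 9

# The matrix from my program above: only 4 of the 9 entries are stored,
# the other 5 positions are structural zeros.
row_vec = torch.tensor([0, 0, 1, 2])
col_vec = torch.tensor([0, 2, 1, 2])
val_vec = torch.tensor([3.0, 4.0, 5.0, 6.0], dtype=torch.float64)
A_structural = torch.sparse_coo_tensor(torch.stack((row_vec, col_vec), 0), val_vec, (3, 3))
print(A_structural.coalesce().values().numel())   # 4
```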
The derivative formula from your blog is
$$\frac{\partial L}{\partial A} = - \frac{\partial L}{\partial b} \otimes x$$
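Written out element-wise (my reading of that formula, taking $\otimes$ to be the outer product of $\frac{\partial L}{\partial b}$ with the solution $x$, which is also what the `backward()` above computes at the stored indices):

$$\frac{\partial L}{\partial b} = A^{-T}\,\frac{\partial L}{\partial x}, \qquad \frac{\partial L}{\partial A_{ij}} = -\left(\frac{\partial L}{\partial b}\right)_i x_j$$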
Since the matrix A is sparse, $\frac{\partial L}{\partial A_{ij}}$ should then be $0$ whenever $A_{ij}=0$, but the numerical Jacobian computed by PyTorch shows that this is not the case. If I change the `backward()` function to the following:
```python
@staticmethod
def backward(ctx, grad):
    A, b, x = ctx.saved_tensors
    gradb = SparseSolve.apply(A.t(), grad)
    # Dense gradient with respect to A: the full outer product -gradb x^T
    if x.ndim == 1:
        gradA = -gradb.reshape(-1, 1) @ x.reshape(1, -1)
    else:
        gradA = -gradb @ x.T
    return gradA, gradb
```
Then the gradient check passes. However, `gradA` is now a dense matrix, which is not consistent with the theoretical result. There is a similar issue #13, but without a detailed explanation. So I would like to ask: which gradient is correct, the sparse one or the dense one?
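For reference, here is a small check I added (a sketch only, reusing the A and b from my program above; it is not part of the original code). It shows that the two backward variants agree on the stored entries of A and differ only at the structural zeros:

```python
import torch

# Assumes A and b are defined as in the program above (3x3 sparse COO, float64).
with torch.no_grad():
    Ad = A.to_dense()
    x = torch.linalg.solve(Ad, b)                  # forward solution x = A^{-1} b
    grad_out = torch.ones_like(x)                  # some incoming gradient dL/dx
    gradb = torch.linalg.solve(Ad.T, grad_out)     # dL/db = A^{-T} dL/dx
    dense_gradA = -gradb.reshape(-1, 1) @ x.reshape(1, -1)   # dense outer product
    idx = A.coalesce().indices()
    sparse_vals = -gradb[idx[0]] * x[idx[1]]       # sparse formula: stored entries only
    # The two formulas coincide on the stored entries of A...
    print(torch.allclose(dense_gradA[idx[0], idx[1]], sparse_vals))   # expect: True
    # ...but the dense gradient is also nonzero at the structural zeros of A.
    print(dense_gradA)
```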