Sign in adjoint derivative calculation

Open spenrich opened this issue 5 years ago • 3 comments

I've been following the paper "Differentiating Through a Cone Program" and the code side-by-side, and I'm having trouble figuring out if there is a sign error in the adjoint derivative code or if I've misunderstood something.

https://github.com/cvxgrp/diffcp/blob/83080bcd30775e2a48fbac33ca4165c474a7aa00/diffcp/cone_program.py#L341-L357

It seems like, compared with the paper, the code solves M.T @ r = dz for r, whereas the paper solves M.T @ g = -dz for g, so r = -g. But then the code's equations for (dA, db, dc) seem to match the paper's, when each of them should differ by a sign.
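A minimal NumPy sketch of the sign relationship described above. It assumes a random, diagonally shifted matrix as a stand-in for M (not the matrix diffcp actually builds) and a random incoming gradient dz; solving M.T @ r = dz and M.T @ g = -dz should give r = -g.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 6  # stand-in for n + m + 1
# Random, well-conditioned stand-in for M (an assumption for illustration).
M = rng.standard_normal((N, N)) + N * np.eye(N)
dz = rng.standard_normal(N)  # incoming gradient with respect to z

# Code-style solve: M.T @ r = dz
r = np.linalg.solve(M.T, dz)
# Paper-style solve: M.T @ g = -dz
g = np.linalg.solve(M.T, -dz)

print(np.allclose(r, -g))
```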

Similarly, for the forward-mode derivative, you solve M @ dz = dQ @ pi_z for dz and then use the paper's equations despite the same sign difference, but you multiply (dx, dy, dz) by -1 before returning, so that case works out.
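A sketch of why the forward-mode case is fine, assuming random stand-ins for M, dQ, and pi(z), and a hypothetical matrix L standing in for the (linear) map from dz to the returned derivatives. Because that map is linear, negating its outputs at the end is the same as solving with the paper's minus sign up front.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 6  # stand-in for n + m + 1
M = rng.standard_normal((N, N)) + N * np.eye(N)  # random stand-in for M
dQ = rng.standard_normal((N, N))                  # random stand-in for dQ
pi_z = rng.standard_normal(N)                     # random stand-in for pi(z)
# Hypothetical linear map standing in for dz -> (dx, dy, ...).
L = rng.standard_normal((3, N))

# Paper: dz = -M^{-1} dQ pi(z), then apply the linear map.
out_paper = L @ (-np.linalg.solve(M, dQ @ pi_z))

# Code-style: solve M @ dz = dQ @ pi_z, apply the map, negate before returning.
out_code = -(L @ np.linalg.solve(M, dQ @ pi_z))

print(np.allclose(out_paper, out_code))
```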

Is this a sign error in the adjoint derivative, or did I get something wrong?

spenrich, May 18 '20 20:05

Good question!

The minus sign makes its way to this line: https://github.com/cvxgrp/diffcp/blob/83080bcd30775e2a48fbac33ca4165c474a7aa00/diffcp/cone_program.py#L352

Because otherwise it would be:

```python
values = -pi_z[cols] * r[rows + n] + pi_z[n + rows] * r[cols]
```

Sorry that it's a bit confusing. We should probably just move it up to the solve line. This minus sign was certainly a pain for us when we were initially debugging.
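A sketch of why the minus sign can live either in the solve or in the assembly line. It uses hypothetical sizes, a hypothetical sparsity pattern (rows, cols) for A, and random stand-ins for M, dz, and pi_z. The assembly function is the (i, j) entry of dQ_12^T - dQ_21 with dQ = outer(g, pi_z); since it is linear in its argument, any linear assembly formula would behave the same way.

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(2)
n, m = 4, 3
N = n + m + 1
M = rng.standard_normal((N, N)) + N * np.eye(N)  # random stand-in for M
dz = rng.standard_normal(N)                       # incoming gradient
pi_z = rng.standard_normal(N)                     # random stand-in for pi(z)
# Hypothetical sparsity pattern of A (m x n): row/column indices of nonzeros.
rows = np.array([0, 1, 2, 2])
cols = np.array([0, 3, 1, 2])


def assemble(vec):
    """(i, j) entries of dQ_12^T - dQ_21 where dQ = outer(vec, pi_z)."""
    return vec[cols] * pi_z[n + rows] - vec[n + rows] * pi_z[cols]


# Minus sign in the solve (paper-style): M.T @ g = -dz.
g = np.linalg.solve(M.T, -dz)
dA_g = sp.csc_matrix((assemble(g), (rows, cols)), shape=(m, n))

# Minus sign folded into the assembly line (code-style): M.T @ r = dz.
r = np.linalg.solve(M.T, dz)
dA_r = sp.csc_matrix((-assemble(r), (rows, cols)), shape=(m, n))

print(np.allclose(dA_g.toarray(), dA_r.toarray()))
```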

sbarratt, May 18 '20 20:05

I've included some example code (test_diffcp_vjp.txt) in which I copied in solve_and_derivative_internal and edited the returned DT function to also output r and pi_z. I then run this on the cone program from the README, explicitly compute -dQ_12.T + dQ_21 (from the paper) with dQ = np.outer(g, pi_z), where g = -r, and compare the result with the dA returned by solve_and_derivative_internal. For a fixed random seed, I get:

```
-dQ_12.T + dQ_21 =
[[-2.8563  0.7865 -1.6522  0.0864  1.2953]
 [-0.6479  0.1784 -0.3748  0.0196  0.2938]
 [-1.5848  0.4364 -0.9167  0.0479  0.7187]
 [-0.      0.     -0.      0.      0.    ]
 [ 0.     -0.      0.     -0.     -0.    ]
 [ 1.5824 -0.4357  0.9153 -0.0479 -0.7176]
 [-0.      0.     -0.      0.      0.    ]
 [ 0.     -0.      0.     -0.     -0.    ]
 [ 0.     -0.      0.     -0.     -0.    ]
 [ 0.     -0.      0.     -0.     -0.    ]
 [ 0.     -0.      0.     -0.     -0.    ]]

dA from solve_and_derivative_internal =
[[ 2.8563 -0.7865  1.6522 -0.0864 -1.2953]
 [ 0.6479 -0.1784  0.3748 -0.0196 -0.2938]
 [ 1.5848 -0.4364  0.9167 -0.0479 -0.7187]
 [ 0.     -0.      0.     -0.     -0.    ]
 [-0.      0.     -0.      0.      0.    ]
 [-1.5824  0.4357 -0.9153  0.0479  0.7176]
 [ 0.     -0.      0.     -0.     -0.    ]
 [-0.      0.     -0.      0.      0.    ]
 [-0.      0.     -0.      0.      0.    ]
 [-0.      0.     -0.      0.      0.    ]
 [-0.      0.     -0.      0.      0.    ]]
```

There seems to be a difference in sign. Am I computing dQ correctly here?

spenrich, May 18 '20 22:05
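A sketch of the dense check described above, using random stand-ins for r and pi_z rather than the README problem. It assumes the paper's variable ordering z = (u, v, w), with u of length n, v of length m, and w a scalar, and sizes inferred from the 11 x 5 dA printed above. It shows how the blocks of dQ = np.outer(g, pi_z) are sliced and that -dQ_12.T + dQ_21 equals the corresponding entrywise formula.

```python
import numpy as np

rng = np.random.default_rng(3)
n, m = 5, 11  # sizes matching the 11 x 5 dA printed above
N = n + m + 1
r = rng.standard_normal(N)     # random stand-in for the adjoint solve's output
pi_z = rng.standard_normal(N)  # random stand-in for pi(z)

g = -r
dQ = np.outer(g, pi_z)
# Block slices for z = (u, v, w): u has length n, v has length m, w is scalar.
dQ_12 = dQ[:n, n:n + m]  # n x m
dQ_21 = dQ[n:n + m, :n]  # m x n

paper_dA = -dQ_12.T + dQ_21  # the expression from the paper being checked

# Same quantity written entrywise for every (i, j), as a sparse assembly would.
i, j = np.meshgrid(np.arange(m), np.arange(n), indexing="ij")
entrywise = -g[j] * pi_z[n + i] + g[n + i] * pi_z[j]

print(np.allclose(paper_dA, entrywise))
```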

It looks like the paper is wrong but the code is right! We're in the process of updating the equations for dA, db, and dc. They should be

[Image: corrected equations for dA, db, and dc]
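For reference, a derivation sketch of where the corrected signs come from. It assumes the paper's skew-symmetric Q, the forward relation Delta z = -M^{-1} Delta Q pi(z), and the adjoint solve M^T g = -dz from the question, where Delta marks a forward perturbation and dz is the incoming gradient. The resulting dA = dQ_12^T - dQ_21 is the negative of the -dQ_12.T + dQ_21 expression checked above, consistent with the two printed matrices; it is a sketch, not the paper's updated text.

```latex
% Q is skew-symmetric and linear in (A, b, c); forward relation for z:
Q = \begin{bmatrix} 0 & A^T & c \\ -A & 0 & b \\ -c^T & -b^T & 0 \end{bmatrix},
\qquad
\Delta z = -M^{-1}\,\Delta Q\,\pi(z).

% Pair the incoming gradient dz with a forward perturbation, where M^T g = -dz
% and G = g\,\pi(z)^T is partitioned like Q:
\mathrm{d}z^T \Delta z
  = g^T \Delta Q\,\pi(z)
  = \langle G, \Delta Q \rangle
  = \langle G_{12}^T - G_{21}, \Delta A \rangle
  + \langle G_{23} - G_{32}^T, \Delta b \rangle
  + \langle G_{13} - G_{31}^T, \Delta c \rangle.

% Reading off the gradients, writing dQ := G as in the thread:
\mathrm{d}A = \mathrm{d}Q_{12}^T - \mathrm{d}Q_{21}, \qquad
\mathrm{d}b = \mathrm{d}Q_{23} - \mathrm{d}Q_{32}^T, \qquad
\mathrm{d}c = \mathrm{d}Q_{13} - \mathrm{d}Q_{31}^T.
```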

Thanks so much for finding this and contacting us about it.

sbarratt, May 20 '20 16:05
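A numerical adjoint (dot-product) test is one way to settle a sign question like this independently of both the paper and the code: a correct adjoint J^T must satisfy <dz, J dtheta> = <J^T dz, dtheta> for random inputs. This is a sketch under stated assumptions: M, pi_z, and the perturbations are random stand-ins rather than quantities from an actual cone program, and it only checks the step between (dA, db, dc) and dz, not the cone projections that relate dz to (dx, dy, ds).

```python
import numpy as np

rng = np.random.default_rng(4)
n, m = 4, 3
N = n + m + 1
M = rng.standard_normal((N, N)) + N * np.eye(N)  # random stand-in for M
pi_z = rng.standard_normal(N)                     # random stand-in for pi(z)


def build_dQ(dA, db, dc):
    """Perturbation of Q = [[0, A^T, c], [-A, 0, b], [-c^T, -b^T, 0]]."""
    return np.block([
        [np.zeros((n, n)), dA.T, dc[:, None]],
        [-dA, np.zeros((m, m)), db[:, None]],
        [-dc[None, :], -db[None, :], np.zeros((1, 1))],
    ])


def forward(dA, db, dc):
    """Forward mode: Delta z = -M^{-1} dQ pi(z)."""
    return -np.linalg.solve(M, build_dQ(dA, db, dc) @ pi_z)


def adjoint(dz):
    """Adjoint with the corrected signs: M^T g = -dz, G = g pi(z)^T."""
    g = np.linalg.solve(M.T, -dz)
    G = np.outer(g, pi_z)
    grad_A = G[:n, n:n + m].T - G[n:n + m, :n]      # dQ_12^T - dQ_21
    grad_b = G[n:n + m, n + m] - G[n + m, n:n + m]  # dQ_23 - dQ_32^T
    grad_c = G[:n, n + m] - G[n + m, :n]            # dQ_13 - dQ_31^T
    return grad_A, grad_b, grad_c


# Random forward perturbation and random incoming gradient.
dA = rng.standard_normal((m, n))
db = rng.standard_normal(m)
dc = rng.standard_normal(n)
dz = rng.standard_normal(N)

lhs = dz @ forward(dA, db, dc)
grad_A, grad_b, grad_c = adjoint(dz)
rhs = np.sum(grad_A * dA) + grad_b @ db + grad_c @ dc

# A correct adjoint satisfies <dz, J dtheta> = <J^T dz, dtheta>.
print(np.isclose(lhs, rhs))
# Negating all three gradient formulas flips the sign of rhs, so that variant fails.
print(np.isclose(lhs, -rhs))
```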