diffcp
Performance squeeze: Don't pre-compute anything for the derivative when initially solving
It should be pretty easy to update solve_and_derivative so that it does not pre-compute anything for the derivative when the user doesn't need it. Otherwise, anybody who uses the derivative only sometimes (but not always) pays the pre-computation overhead on every solve, even when the derivative is never used. This also makes the timing results of the forward/backward passes slightly off, since time that should be measured in the backward pass is actually spent in the forward pass. I think this overhead might be even larger for the explicit mode in #2, which calls into cone_lib.dpi_explicit.
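A minimal sketch of the deferred pre-computation, shown on a toy dense linear solve rather than on diffcp itself. The names `LazyDerivative`, `solve_lazy`, and `precompute_calls` are hypothetical and exist only for illustration; the explicit inverse stands in for whatever expensive derivative setup the real solver needs.

```python
import numpy as np


class LazyDerivative:
    """Toy stand-in (not diffcp code): derivative of x = A^{-1} b w.r.t. b."""

    def __init__(self, A):
        self._A = A
        self._A_inv = None         # expensive quantity, filled in on first use
        self.precompute_calls = 0  # counter kept only to illustrate the laziness

    def _ensure_precomputed(self):
        # Runs the expensive step at most once, and only if a derivative is requested.
        if self._A_inv is None:
            self._A_inv = np.linalg.inv(self._A)
            self.precompute_calls += 1
        return self._A_inv

    def D(self, db):
        """Apply the derivative to a perturbation db."""
        return self._ensure_precomputed() @ db

    def DT(self, dx):
        """Apply the adjoint of the derivative to dx."""
        return self._ensure_precomputed().T @ dx


def solve_lazy(A, b):
    # Forward pass: solve only, no derivative work is done here.
    x = np.linalg.solve(A, b)
    return x, LazyDerivative(A)
```

With this structure, a caller who never touches `D` or `DT` pays only for the solve, and the pre-computation time shows up in whichever derivative call happens first.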
I just tried running the following quick example, and the derivative pre-computation seems to add ~15% overhead:
#!/usr/bin/env python3
import numpy as np
from scipy import sparse
import diffcp
nzero = 100
npos = 100
nsoc = 100
m = nzero + npos + nsoc
n = 100
cone_dict = {
    diffcp.ZERO: nzero,
    diffcp.POS: npos,
    diffcp.SOC: [nsoc]
}
A, b, c = diffcp.utils.random_cone_prog(m, n, cone_dict)
x, y, s, D, DT = diffcp.solve_and_derivative(A, b, c, cone_dict)
# evaluate the derivative
nonzeros = A.nonzero()
data = 1e-4 * np.random.randn(A.size)
dA = sparse.csc_matrix((data, nonzeros), shape=A.shape)
db = 1e-4 * np.random.randn(m)
dc = 1e-4 * np.random.randn(n)
dx, dy, ds = D(dA, db, dc)
# evaluate the adjoint of the derivative
dx = c
dy = np.zeros(m)
ds = np.zeros(m)
dA, db, dc = DT(dx, dy, ds)
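To attribute time to the right pass, each phase of the script above could be timed on its own. This is a minimal sketch with a hypothetical `timed` helper (not part of diffcp); the commented lines show how it might wrap the calls used above.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Assumed usage with the names from the script above:
# (x, y, s, D, DT), t_forward = timed(diffcp.solve_and_derivative, A, b, c, cone_dict)
# (dx, dy, ds), t_derivative = timed(D, dA, db, dc)
# (dA, db, dc), t_adjoint = timed(DT, dx, dy, ds)
```

Any derivative setup that runs inside the solve call is counted in `t_forward` under this scheme, which is the skew described above.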
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    55                                           @profile
    56                                           def solve_and_derivative(A, b, c, cone_dict, warm_start=None, **kwargs):
...
   128         1      79706.0  79706.0     84.1      result = scs.solve(data, cone_dict, **kwargs)
   129
   130                                               # check status
   131         1          6.0      6.0      0.0      status = result["info"]["status"]
   132         1          4.0      4.0      0.0      if status == "Solved/Innacurate":
   133                                                   warnings.warn("Solved/Innacurate.")
   134         1          4.0      4.0      0.0      elif status != "Solved":
   135                                                   raise SolverError("Solver scs returned status %s" % status)
   136
   137         1          3.0      3.0      0.0      x = result["x"]
   138         1          4.0      4.0      0.0      y = result["y"]
   139         1          3.0      3.0      0.0      s = result["s"]
   140
   141                                               # pre-compute quantities for the derivative
   142         1          7.0      7.0      0.0      m, n = A.shape
   143         1          4.0      4.0      0.0      N = m + n + 1
   144         1         14.0     14.0      0.0      cones = cone_lib.parse_cone_dict(cone_dict)
   145         1         21.0     21.0      0.0      z = (x, y - s, np.array([1]))
   146         1          4.0      4.0      0.0      u, v, w = z
   147         1       1850.0   1850.0      2.0      D_proj_dual_cone = cone_lib.dpi(v, cones, dual=True)
   148         1          5.0      5.0      0.0      Q = sparse.bmat([
   149         1        271.0    271.0      0.3          [None, A.T, np.expand_dims(c, - 1)],
   150         1        299.0    299.0      0.3          [-A, None, np.expand_dims(b, -1)],
   151         1       4230.0   4230.0      4.5          [-np.expand_dims(c, -1).T, -np.expand_dims(b, -1).T, None]
   152                                               ])
   153         1       2878.0   2878.0      3.0      M = splinalg.aslinearoperator(Q - sparse.eye(N)) @ dpi(
   154         1       3301.0   3301.0      3.5          z, cones) + splinalg.aslinearoperator(sparse.eye(N))
   155         1        445.0    445.0      0.5      pi_z = pi(z, cones)
   156         1       1742.0   1742.0      1.8      rows, cols = A.nonzero()