
Performance squeeze: Don't pre-compute anything for the derivative when initially solving

Open · bamos opened this issue 4 years ago · 0 comments

It should be pretty easy to update `solve_and_derivative` so that it doesn't pre-compute anything for the derivative when the user doesn't need it. Otherwise, anyone who needs the derivative only some of the time pays the pre-computation overhead on every solve, including the solves where they never use it. It also skews the forward/backward timing results slightly, because time that belongs to the backward pass is counted in the forward pass. I think this overhead might be even larger for the explicit mode in #2, which calls into `cone_lib.dpi_explicit`.
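A minimal sketch of the idea, assuming the derivative set-up can be wrapped in a zero-argument `precompute` callable and the derivative maps in `apply_D`/`apply_DT`. All names here (`make_lazy_derivative`, `solve_and_derivative_sketch`, `time_passes`) are hypothetical and not part of diffcp; a `time.sleep` stands in for the real pre-computation. `D` and `DT` share one cached result, so the set-up runs at most once, and only if a derivative is actually requested.

```python
import functools
import time

# Hypothetical helpers for illustration; not diffcp's API.


def make_lazy_derivative(precompute, apply_D, apply_DT):
    """Return (D, DT) that run `precompute` only on their first call."""

    @functools.lru_cache(maxsize=None)
    def cached():
        # Executed at most once; later calls reuse the stored result.
        return precompute()

    def D(*args):
        return apply_D(cached(), *args)

    def DT(*args):
        return apply_DT(cached(), *args)

    return D, DT


def solve_and_derivative_sketch(solve, precompute, apply_D, apply_DT, lazy=True):
    """Toy stand-in for solve_and_derivative with an eager/lazy switch."""
    solution = solve()
    if lazy:
        D, DT = make_lazy_derivative(precompute, apply_D, apply_DT)
    else:
        data = precompute()  # eager: cost is charged to the forward pass

        def D(*args):
            return apply_D(data, *args)

        def DT(*args):
            return apply_DT(data, *args)

    return solution, D, DT


def time_passes(lazy, precompute_seconds=0.05):
    """Time the forward call and the first derivative call separately."""

    def slow_precompute():
        time.sleep(precompute_seconds)  # stands in for dpi, bmat, ...
        return 2.0

    t0 = time.perf_counter()
    _, D, _ = solve_and_derivative_sketch(
        solve=lambda: None,
        precompute=slow_precompute,
        apply_D=lambda data, dx: data * dx,
        apply_DT=lambda data, dy: data * dy,
        lazy=lazy,
    )
    t1 = time.perf_counter()
    D(1.0)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1
```

With `lazy=False` the sleep shows up in the forward timing; with `lazy=True` it moves to the first `D` call, which is where it belongs when timing the backward pass.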

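I just tried running the following quick example, and the derivative pre-computation seems to add ~15% overhead: in the profile below, the pre-computation lines (142-156) account for 15.9% of the time spent in `solve_and_derivative`.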

```python
#!/usr/bin/env python3

import numpy as np
from scipy import sparse

import diffcp

nzero = 100
npos = 100
nsoc = 100
m = nzero + npos + nsoc
n = 100

cone_dict = {
    diffcp.ZERO: nzero,
    diffcp.POS: npos,
    diffcp.SOC: [nsoc]
}

A, b, c = diffcp.utils.random_cone_prog(m, n, cone_dict)
# solve the cone program; D is the derivative and DT its adjoint
x, y, s, D, DT = diffcp.solve_and_derivative(A, b, c, cone_dict)

# evaluate the derivative on a small perturbation with A's sparsity pattern
nonzeros = A.nonzero()
data = 1e-4 * np.random.randn(len(nonzeros[0]))  # one value per nonzero of A
dA = sparse.csc_matrix((data, nonzeros), shape=A.shape)
db = 1e-4 * np.random.randn(m)
dc = 1e-4 * np.random.randn(n)
dx, dy, ds = D(dA, db, dc)

# evaluate the adjoint of the derivative at dx = c, dy = 0, ds = 0
dx = c
dy = np.zeros(m)
ds = np.zeros(m)
dA, db, dc = DT(dx, dy, ds)
```
Line-by-line profile of `solve_and_derivative` for the example above (`line_profiler` output):

```
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    55                                           @profile
    56                                           def solve_and_derivative(A, b, c, cone_dict, warm_start=None, **kwargs):
...
   128         1      79706.0  79706.0     84.1      result = scs.solve(data, cone_dict, **kwargs)
   129
   130                                               # check status
   131         1          6.0      6.0      0.0      status = result["info"]["status"]
   132         1          4.0      4.0      0.0      if status == "Solved/Innacurate":
   133                                                   warnings.warn("Solved/Innacurate.")
   134         1          4.0      4.0      0.0      elif status != "Solved":
   135                                                   raise SolverError("Solver scs returned status %s" % status)
   136
   137         1          3.0      3.0      0.0      x = result["x"]
   138         1          4.0      4.0      0.0      y = result["y"]
   139         1          3.0      3.0      0.0      s = result["s"]
   140
   141                                               # pre-compute quantities for the derivative
   142         1          7.0      7.0      0.0      m, n = A.shape
   143         1          4.0      4.0      0.0      N = m + n + 1
   144         1         14.0     14.0      0.0      cones = cone_lib.parse_cone_dict(cone_dict)
   145         1         21.0     21.0      0.0      z = (x, y - s, np.array([1]))
   146         1          4.0      4.0      0.0      u, v, w = z
   147         1       1850.0   1850.0      2.0      D_proj_dual_cone = cone_lib.dpi(v, cones, dual=True)
   148         1          5.0      5.0      0.0      Q = sparse.bmat([
   149         1        271.0    271.0      0.3          [None, A.T, np.expand_dims(c, - 1)],
   150         1        299.0    299.0      0.3          [-A, None, np.expand_dims(b, -1)],
   151         1       4230.0   4230.0      4.5          [-np.expand_dims(c, -1).T, -np.expand_dims(b, -1).T, None]
   152                                               ])
   153         1       2878.0   2878.0      3.0      M = splinalg.aslinearoperator(Q - sparse.eye(N)) @ dpi(
   154         1       3301.0   3301.0      3.5          z, cones) + splinalg.aslinearoperator(sparse.eye(N))
   155         1        445.0    445.0      0.5      pi_z = pi(z, cones)
   156         1       1742.0   1742.0      1.8      rows, cols = A.nonzero()
```
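To make the deferred work concrete, here is a dense NumPy sketch of the two main matrices built in lines 148-154 of the profile: the skew-symmetric `Q` and `M = (Q - I) Dpi(z) + I` with `z = (x, y - s, 1)`. It assumes every row of `A` belongs to the nonnegative cone (self-dual, so `Dpi(z)` reduces to a 0/1 diagonal mask); `build_Q` and `build_M_nonneg` are hypothetical helpers, and the profiled code uses `sparse.bmat` and a `LinearOperator` rather than dense arrays. The value used where `v_i == 0` is a convention and may differ from diffcp's.

```python
import numpy as np

# Hypothetical dense helpers for illustration; not diffcp's sparse code.


def build_Q(A, b, c):
    """Q = [[0, A^T, c], [-A, 0, b], [-c^T, -b^T, 0]], as in the bmat call."""
    m, n = A.shape
    b = b.reshape(-1, 1)
    c = c.reshape(-1, 1)
    return np.block([
        [np.zeros((n, n)), A.T, c],
        [-A, np.zeros((m, m)), b],
        [-c.T, -b.T, np.zeros((1, 1))],
    ])


def build_M_nonneg(A, b, c, y, s):
    """M = (Q - I) Dpi(z) + I for z = (x, y - s, 1).

    Assumes a nonnegative cone only, so Dpi(z) is diagonal: 1 on the x
    block, a 0/1 mask on v = y - s, and 1 on the last entry (w = 1 > 0).
    """
    m, n = A.shape
    N = m + n + 1
    v = y - s
    dpi_diag = np.concatenate([np.ones(n), (v > 0).astype(float), [1.0]])
    Q = build_Q(A, b, c)
    return (Q - np.eye(N)) @ np.diag(dpi_diag) + np.eye(N)
```

Both matrices depend only on the problem data and the solution `(y, s)`, so they could be built on the first call to `D` or `DT` rather than inside the solve.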

bamos, Oct 03 '19 14:10