
Calling `jax.jacobian` results in an error

Open GeoffNN opened this issue 4 years ago • 0 comments

Hi, thanks for the cool package! I've just started playing around with it on a QP I need to differentiate through. I get the following error when calling `jax.jacobian` on the layer. The forward pass works fine and finds a nontrivial solution, as the solver's verbose output shows. Any help here would be appreciated!
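A side note on which Jacobian is being computed: `jax.jacobian` is documented as an alias of `jax.jacrev` (reverse mode), but the top frames of the traceback below (`pushfwd = partial(_jvp, ...)`, `_check_input_dtype_jacfwd`) look like JAX's forward-mode code, i.e. `jax.jacfwd`. If the layer's derivative is registered only as a custom VJP (an assumption here, not confirmed by this issue), forward mode cannot pass through it. The sketch below uses a hypothetical stand-in, `solver_stand_in`, in place of the layer to show the difference. It does not address the `AssertionError` itself, which is raised earlier, while `CvxpyLayer(...)` is being constructed inside the lambda.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a differentiable solver layer whose derivative is
# supplied only as a custom VJP (reverse mode). Here the "solution" is 2 * theta.
@jax.custom_vjp
def solver_stand_in(theta):
    return 2.0 * theta


def _fwd(theta):
    # Forward pass: return the output plus residuals (none needed here).
    return solver_stand_in(theta), None


def _bwd(_residuals, g):
    # Backward pass: vector-Jacobian product g^T J with J = 2 * I.
    return (2.0 * g,)


solver_stand_in.defvjp(_fwd, _bwd)

theta = jnp.arange(3.0)

# jax.jacobian is an alias of jax.jacrev, which only needs the custom VJP.
J = jax.jacobian(solver_stand_in)(theta)

# jax.jacfwd pushes tangents forward (JVPs), which a custom_vjp function rejects.
try:
    jax.jacfwd(solver_stand_in)(theta)
    forward_mode_error = None
except Exception as err:  # the exact exception type may vary by JAX version
    forward_mode_error = err
```

Here `J` is the 3x3 matrix `2 * I`, while the forward-mode call fails. If the real layer behaves like this stand-in, reverse mode (`jax.jacobian` / `jax.jacrev`) is the one to use.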

Solver output:

----------------------------------------------------------------------------
	SCS v2.1.4 - Splitting Conic Solver
	(c) Brendan O'Donoghue, Stanford University, 2012
----------------------------------------------------------------------------
Lin-sys: sparse-direct, nnz in A = 20101
eps = 1.00e-05, alpha = 1.50, max_iters = 5000, normalize = 1, scale = 1.00
acceleration_lookback = 10, rho_x = 1.00e-03
Variables n = 102, constraints m = 202
Cones:	primal zero / dual free vars: 100
	soc vars: 102, soc blks: 1
Setup time: 4.14e-02s
----------------------------------------------------------------------------
 Iter | pri res | dua res | rel gap | pri obj | dua obj | kap/tau | time (s)
----------------------------------------------------------------------------
     0| 2.91e+19  2.29e+19  1.00e+00 -6.29e+19  1.45e+18  5.11e+19  2.14e-02 
   100| 3.48e-05  1.08e-05  5.48e-05  5.81e-02  5.82e-02  3.46e-18  4.10e-02 
   200| 3.62e-03  1.26e-03  1.62e-03  4.83e-02  4.65e-02  2.10e-17  5.60e-02 
   300| 1.37e-01  1.28e-02  3.03e-02  1.30e-02  4.51e-02  2.48e-17  6.72e-02 
   400| 5.11e-02  1.60e-02  2.01e-02  7.96e-02  5.67e-02  2.13e-17  7.94e-02 
   500| 1.51e-02  5.96e-03  5.18e-03  4.17e-02  3.61e-02  7.22e-18  9.34e-02 
   600| 4.03e-04  1.67e-04  1.32e-04  3.91e-02  3.92e-02  3.66e-17  1.05e-01 
   700| 1.25e-03  2.41e-04  1.02e-03  3.82e-02  3.93e-02  1.07e-18  1.18e-01 
   800| 8.28e-03  2.12e-03  2.41e-03  3.25e-02  3.51e-02  2.42e-17  1.30e-01 
   900| 9.51e-04  4.79e-04  1.53e-04  3.87e-02  3.88e-02  3.10e-17  1.42e-01 
  1000| 1.73e-03  1.01e-03  1.14e-03  3.82e-02  3.95e-02  7.56e-18  1.58e-01 
  1100| 2.17e-03  1.14e-03  5.81e-04  3.79e-02  3.86e-02  6.94e-18  1.76e-01 
  1200| 2.57e-02  9.40e-03  5.71e-03  3.29e-02  3.91e-02  4.66e-17  1.93e-01 
  1300| 5.11e-03  2.51e-03  7.59e-03  4.72e-02  3.90e-02  5.41e-18  2.07e-01 
  1400| 1.46e-03  3.23e-04  5.90e-06  3.68e-02  3.68e-02  3.53e-18  2.19e-01 
  1500| 2.58e-03  2.64e-04  1.59e-03  3.64e-02  3.81e-02  5.11e-18  2.36e-01 
  1600| 1.07e-03  5.72e-04  2.59e-04  3.77e-02  3.80e-02  1.17e-17  2.49e-01 
  1700| 1.21e-03  5.51e-04  9.79e-04  3.82e-02  3.72e-02  2.62e-17  2.63e-01 
  1800| 6.70e-04  1.30e-04  2.71e-04  3.71e-02  3.74e-02  3.90e-17  2.76e-01 
  1900| 2.35e-02  1.15e-02  4.02e-03  4.06e-02  3.62e-02  3.50e-17  2.88e-01 
  2000| 1.21e-03  1.82e-04  1.99e-03  3.83e-02  3.61e-02  1.90e-17  3.01e-01 
  2100| 4.19e-02  1.53e-02  1.78e-02  5.94e-02  3.98e-02  4.11e-19  3.12e-01 
  2200| 9.61e-04  3.36e-04  5.88e-04  3.86e-02  3.79e-02  1.78e-17  3.27e-01 
  2300| 1.10e-03  5.49e-04  7.74e-04  3.86e-02  3.78e-02  2.92e-17  3.40e-01 
  2400| 5.45e-03  1.52e-03  2.36e-03  4.30e-02  4.04e-02  7.82e-17  3.57e-01 
  2500| 1.66e-02  9.82e-03  1.03e-02  5.18e-02  4.06e-02  3.10e-17  3.69e-01 
  2600| 1.31e-04  7.45e-05  9.50e-05  3.75e-02  3.76e-02  1.97e-17  3.86e-01 
  2700| 4.12e-04  2.23e-04  4.29e-04  3.79e-02  3.75e-02  5.47e-18  3.99e-01 
  2800| 2.99e-04  9.80e-05  3.23e-05  3.76e-02  3.76e-02  1.43e-17  4.11e-01 
  2900| 1.62e-03  9.50e-04  2.36e-03  3.44e-02  3.69e-02  2.24e-17  4.23e-01 
  3000| 5.68e-04  3.09e-04  3.79e-04  3.79e-02  3.75e-02  4.01e-18  4.35e-01 
  3100| 1.02e-02  5.00e-03  1.26e-03  4.00e-02  4.14e-02  2.91e-17  4.47e-01 
  3200| 9.33e-03  4.59e-03  3.19e-03  4.10e-02  3.75e-02  6.83e-18  4.60e-01 
  3300| 1.20e-03  6.14e-04  7.47e-04  3.67e-02  3.75e-02  9.50e-18  4.72e-01 
  3400| 4.41e-02  2.47e-02  2.23e-02  1.80e-02  4.17e-02  2.27e-17  4.84e-01 
  3500| 6.01e-03  3.32e-03  2.26e-03  2.76e-02  2.52e-02  5.72e-17  4.96e-01 
  3600| 7.45e-04  8.20e-05  9.78e-05  3.70e-02  3.69e-02  1.19e-17  5.08e-01 
  3700| 1.16e-04  4.97e-05  1.93e-04  3.72e-02  3.70e-02  2.38e-17  5.20e-01 
  3800| 3.30e-03  1.48e-03  5.65e-04  3.74e-02  3.67e-02  6.22e-18  5.32e-01 
  3900| 7.25e-03  3.55e-03  1.21e-03  3.47e-02  3.35e-02  3.64e-18  5.44e-01 
  4000| 1.61e-03  3.41e-04  4.68e-04  3.70e-02  3.65e-02  2.64e-17  5.60e-01 
  4100| 5.24e-02  1.16e-02  1.15e-02  6.02e-02  4.74e-02  6.24e-18  5.73e-01 
  4200| 4.50e-03  2.61e-03  1.55e-03  4.08e-02  3.91e-02  1.08e-18  5.84e-01 
  4300| 7.94e-03  4.51e-03  8.34e-04  3.66e-02  3.75e-02  3.63e-17  6.00e-01 
  4400| 3.09e-03  1.80e-03  2.32e-03  3.97e-02  4.22e-02  3.98e-17  6.12e-01 
  4500| 6.47e-04  2.21e-04  1.00e-03  3.60e-02  3.70e-02  6.48e-18  6.24e-01 
  4600| 1.02e-02  4.87e-03  7.15e-03  4.29e-02  3.51e-02  1.17e-17  6.35e-01 
  4700| 6.71e-04  3.42e-04  2.42e-04  3.68e-02  3.71e-02  1.07e-18  6.48e-01 
  4800| 2.58e-03  1.50e-03  2.71e-04  3.53e-02  3.56e-02  1.59e-17  6.59e-01 
  4900| 2.89e-01  1.71e-01  7.00e-02 -1.56e-01 -6.99e-02  1.67e-17  6.71e-01 
  5000| 6.05e-03  2.44e-03  7.90e-04  3.51e-02  3.59e-02  1.61e-17  6.83e-01 
----------------------------------------------------------------------------
Status: Solved/Inaccurate
Hit max_iters, solution may be inaccurate, returning best found solution.
Timing: Solve time: 6.83e-01s
	Lin-sys: nnz in L factor: 25455, avg solve time: 7.81e-05s
	Cones: avg projection time: 3.44e-07s
	Acceleration: avg step time: 3.39e-05s
----------------------------------------------------------------------------
Error metrics:
dist(s, K) = 5.0672e-20, dist(y, K*) = 0.0000e+00, s'y/|s||y| = 3.5284e-16
primal res: |Ax + s - b|_2 / (1 + |b|_2) = 1.8102e-05
dual res:   |A'y + c|_2 / (1 + |c|_2) = 4.0469e-06
rel gap:    |c'x + b'y| / (1 + |c'x| + |b'y|) = 9.8806e-06
----------------------------------------------------------------------------
c'x = 0.0391, -b'y = 0.0391
============================================================================
----------------------------------------------------------------------------
	SCS v2.1.4 - Splitting Conic Solver
	(c) Brendan O'Donoghue, Stanford University, 2012
----------------------------------------------------------------------------
Lin-sys: sparse-direct, nnz in A = 20101
eps = 1.00e-05, alpha = 1.50, max_iters = 5000, normalize = 1, scale = 1.00
acceleration_lookback = 0, rho_x = 1.00e-03
Variables n = 102, constraints m = 202
Cones:	primal zero / dual free vars: 100
	soc vars: 102, soc blks: 1
Setup time: 6.42e-02s
----------------------------------------------------------------------------
 Iter | pri res | dua res | rel gap | pri obj | dua obj | kap/tau | time (s)
----------------------------------------------------------------------------
     0| 2.91e+19  2.29e+19  1.00e+00 -6.29e+19  1.45e+18  5.11e+19  5.68e-02 
   100| 3.87e-05  2.47e-07  5.24e-09  5.82e-02  5.82e-02  2.21e-17  7.08e-02 
   140| 6.52e-06  1.66e-07  5.32e-09  5.82e-02  5.82e-02  2.23e-17  7.57e-02 
----------------------------------------------------------------------------
Status: Solved
Timing: Solve time: 7.57e-02s
	Lin-sys: nnz in L factor: 25455, avg solve time: 1.01e-04s
	Cones: avg projection time: 7.41e-07s
	Acceleration: avg step time: 8.07e-08s
----------------------------------------------------------------------------
Error metrics:
dist(s, K) = 4.4318e-19, dist(y, K*) = 1.1102e-16, s'y/|s||y| = 3.7957e-17
primal res: |Ax + s - b|_2 / (1 + |b|_2) = 6.5182e-06
dual res:   |A'y + c|_2 / (1 + |c|_2) = 1.6561e-07
rel gap:    |c'x + b'y| / (1 + |c'x| + |b'y|) = 5.3244e-09
----------------------------------------------------------------------------
c'x = 0.0582, -b'y = 0.0582
============================================================================

Error:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in jacfun(*args, **kwargs)
    970     tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
    971     pushfwd = partial(_jvp, f_partial, dyn_args)
--> 972     y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
    973     tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    974     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

<ipython-input-60-2f66503cf065> in <lambda>(*args, **kwargs)
     16   return lambda *args, **kwargs: CvxpyLayer(prob,
     17                     parameters=[_F, _dx_F, _dt_F, _ground_truth],
---> 18                     variables=[_lam])(*args, **kwargs)[0]

/usr/local/lib/python3.7/dist-packages/cvxpylayers/jax/cvxpylayer.py in CvxpyLayer(problem, parameters, variables, gp)
     74         )
     75     else:
---> 76         data, _, _ = problem.get_problem_data(solver=cp.SCS)
     77         compiler = data[cp.settings.PARAM_PROB]
     78         param_ids = [p.id for p in param_order]

/usr/local/lib/python3.7/dist-packages/cvxpy/problems/problem.py in get_problem_data(self, solver, gp, enforce_dpp, verbose)
    599 
    600             data, solver_inverse_data = solving_chain.solver.apply(
--> 601                 self._cache.param_prog)
    602             inverse_data = self._cache.inverse_data + [solver_inverse_data]
    603             self._compilation_time = time.time() - start

/usr/local/lib/python3.7/dist-packages/cvxpy/reductions/solvers/conic_solvers/scs_conif.py in apply(self, problem)
    219         # Apply parameter values.
    220         # Obtain A, b such that Ax + s = b, s \in cones.
--> 221         c, d, A, b = problem.apply_parameters()
    222         data[s.C] = c
    223         inv_data[s.OFFSET] = d

/usr/local/lib/python3.7/dist-packages/cvxpy/reductions/dcp2cone/cone_matrix_stuffing.py in apply_parameters(self, id_to_param_value, zero_offset, keep_zeros)
    203             self.reduced_A, param_vec, self.x.size,
    204             nonzero_rows=self._A_mapping_nonzero, with_offset=True,
--> 205             problem_data_index=self.problem_data_index)
    206         return c, d, A, np.atleast_1d(b)
    207 

/usr/local/lib/python3.7/dist-packages/cvxpy/cvxcore/python/canonInterface.py in get_matrix_from_tensor(problem_data_tensor, param_vec, var_length, nonzero_rows, with_offset, problem_data_index)
    236     if nonzero_rows is not None and nonzero_rows.size > 0:
    237         A_nrows, _ = A.shape
--> 238         A_rows, A_cols = nonzero_csc_matrix(A)
    239         A_vals = np.append(A.data, np.zeros(nonzero_rows.size))
    240         A_rows = np.append(A_rows, nonzero_rows % A_nrows)

/usr/local/lib/python3.7/dist-packages/cvxpy/cvxcore/python/canonInterface.py in nonzero_csc_matrix(A)
    144     # this function returns (rows, cols) corresponding to nonzero entries in
    145     # A; an entry that is explicitly set to zero is treated as nonzero
--> 146     assert not np.isnan(A.data).any()
    147 
    148     # scipy drops rows, cols with explicit zeros; use nan as a sentinel

AssertionError: 
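The assertion that fails (`assert not np.isnan(A.data).any()` in `nonzero_csc_matrix`) guards the trick described in the comment next to it: scipy's `nonzero()` skips explicitly stored zeros, so cvxpy uses NaN as a temporary marker to keep them. The sketch below re-implements that idea with a hypothetical helper, `stored_entry_positions` (not cvxpy's actual code), and shows that NaN already present in the matrix data trips the same guard. Since `A` is built from `reduced_A` and `param_vec` (see `apply_parameters` above), a NaN there plausibly comes from a parameter value; whether that is the cause in this issue is not established.

```python
import numpy as np
import scipy.sparse as sp


def stored_entry_positions(A):
    """Return (rows, cols) of every stored entry of a sparse matrix.

    Unlike A.nonzero(), explicitly stored zeros are included. NaN is used as a
    temporary sentinel, so the input must not already contain NaN.
    """
    A = sp.csc_matrix(A, dtype=float).copy()  # work on a copy of the data
    # Same guard as the failing line in the traceback: NaN is reserved as the marker.
    assert not np.isnan(A.data).any()
    # NaN != 0, so nonzero() keeps every stored position, including explicit zeros.
    A.data[:] = np.nan
    return A.nonzero()


# A 3x3 CSC matrix whose middle diagonal entry is an explicitly stored zero.
indices = np.array([0, 1, 2])
indptr = np.array([0, 1, 2, 3])
A = sp.csc_matrix((np.array([1.0, 0.0, 3.0]), indices, indptr), shape=(3, 3))

rows_nz, cols_nz = A.nonzero()                   # explicit zero is dropped
rows_all, cols_all = stored_entry_positions(A)   # explicit zero is kept

# NaN in the input data trips the guard, i.e. the error in the traceback.
bad = sp.csc_matrix((np.array([1.0, np.nan, 3.0]), indices, indptr), shape=(3, 3))
try:
    stored_entry_positions(bad)
    nan_rejected = False
except AssertionError:
    nan_rejected = True
```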

GeoffNN, Sep 30 '21 00:09