Assertion error when trying to take jacobian of projected gradient solution using projection_polyhedron

Open Basant1861 opened this issue 3 years ago • 1 comments

The output of compute_pg is a 2D array of shape (12, 1). This is a minimal reproducible example.

import jax
import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_polyhedron

num_obs = 6
C_obs_1 = 1 * jnp.identity(num_obs)
C_obs_2 = -1 * jnp.identity(num_obs)
# Stacked inequality-constraint matrix, shape (4 * num_obs, 2 * num_obs).
C_obs = jnp.block([
    [C_obs_1, 0 * jnp.identity(num_obs)],
    [C_obs_2, 0 * jnp.identity(num_obs)],
    [0 * jnp.identity(num_obs), C_obs_1],
    [0 * jnp.identity(num_obs), C_obs_2],
])

A_obs = jnp.zeros((1, jnp.shape(C_obs)[1]))
a_obstacle = jnp.zeros((1, 1))

def compute_obstacle_penalty_temp(p):
    cost_obs_penalty = 1.0 * jnp.linalg.norm(p) ** 2
    return cost_obs_penalty

def proj(p, C):
    # C receives the hyperparams_proj tuple passed to pg.run below.
    return projection_polyhedron(p, C, check_feasible=False)

def compute_pg(p):
    p = jnp.reshape(p, (jnp.shape(p)[0], 1))
    b_obs = jnp.ones((jnp.shape(C_obs)[0], 1))

    pg = ProjectedGradient(fun=compute_obstacle_penalty_temp, projection=proj, jit=True)
    pg_sol = pg.run(p, hyperparams_proj=(A_obs, a_obstacle, C_obs, b_obs)).params
    return pg_sol

def compute_bilevel():
    # Raises the AssertionError shown below.
    return jax.jacobian(compute_pg)(jnp.ones((12, 1)))

compute_bilevel()

This is the error I get:

File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 41, in <module>
  compute_bilevel()
File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 39, in compute_bilevel
  return jax.jacobian(compute_pg)(jnp.ones((12,1)))
File "/home/ims/.local/lib/python3.8/site-packages/jax/_src/api.py", line 1362, in jacfun
  jac = vmap(pullback)(_std_basis(y))
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
  vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
  u = solve(matvec, v)
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
  Ab = rmatvec(b)  # A.T b
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 145, in <lambda>
  return lambda y: transpose(y)[0]
AssertionError

I observed that if I set implicit_diff=False in ProjectedGradient, it works but is very slow. Please advise.

Basant1861, Nov 02 '22 11:11

Hi Basant1861

When you differentiate a JAXopt solver, it uses implicit differentiation whenever possible. Implicit differentiation is only possible if the argument you differentiate with respect to appears in the optimality conditions of your problem.
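
To make the "appears in the optimality conditions" requirement concrete, here is a minimal sketch in plain JAX. It is not JAXopt's internal code: the toy problem (minimize ||p - theta||^2 over the box [-1, 1]) and the names fixed_point_residual, theta, and step are assumptions made for illustration. The residual F(p, theta) = p - proj(p - step * grad f(p, theta)) vanishes at the solution, and the implicit function theorem gives dp*/dtheta = -(dF/dp)^{-1} dF/dtheta.

```python
import jax
import jax.numpy as jnp

step = 0.25  # assumed step size for this toy problem


def objective(p, theta):
    # Toy objective (an assumption): squared distance to theta.
    return jnp.sum((p - theta) ** 2)


def project_box(p):
    # Projection onto the box -1 <= p <= 1.
    return jnp.clip(p, -1.0, 1.0)


def fixed_point_residual(p, theta):
    # Optimality condition: zero exactly when p is a fixed point of
    # one projected gradient step. Note that it depends on theta but
    # has no dependence on the solver's initialization.
    grad = jax.grad(objective)(p, theta)
    return p - project_box(p - step * grad)


def solve(theta, p_init, num_steps=100):
    # Plain projected gradient iterations to reach the solution.
    p = p_init
    for _ in range(num_steps):
        p = project_box(p - step * jax.grad(objective)(p, theta))
    return p


theta = jnp.array([0.5, 2.0, -0.3])
p_star = solve(theta, jnp.zeros(3))

# Implicit function theorem applied to F(p*, theta) = 0.
dF_dp = jax.jacobian(fixed_point_residual, argnums=0)(p_star, theta)
dF_dtheta = jax.jacobian(fixed_point_residual, argnums=1)(p_star, theta)
dp_dtheta = -jnp.linalg.solve(dF_dp, dF_dtheta)
print(dp_dtheta)  # approximately diag([1, 0, 1])
```

The middle entry is zero because that coordinate is pinned to the box boundary. The initialization p_init never enters fixed_point_residual, so this recipe has nothing to differentiate with respect to it.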

That's your issue: you are trying to differentiate with respect to p (i.e., the initialization of ProjectedGradient), but it does not appear in the optimality conditions. For a convex problem, the initialization plays no role in the solution, so mathematically the derivative should be zero anyway. Numerically, it's trickier.

With implicit differentiation it fails because JAXopt cannot differentiate with respect to variables that are not part of the optimality conditions. Without implicit differentiation, unrolling does return a value. In an ideal world this value would be exactly zero, but with numerical errors I cannot guarantee it (I just tried, and it is around 1e-9). If you want to speed it up, you can try wrapping your whole compute_pg function in jax.jit while keeping implicit differentiation disabled (implicit_diff=False).
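
As a rough illustration of the unrolled case, the sketch below hand-rolls a projected gradient loop in plain JAX instead of calling JAXopt. The names projected_gradient_unrolled, step, and num_steps are assumptions, and a simple box projection stands in for projection_polyhedron. It differentiates through every iteration with respect to the initialization and jits the whole Jacobian computation.

```python
import jax
import jax.numpy as jnp


def objective(p):
    # Same cost as in the issue: squared Euclidean norm.
    return jnp.sum(p ** 2)


def project_box(p):
    # Assumed stand-in for the polyhedron projection: box -1 <= p <= 1.
    return jnp.clip(p, -1.0, 1.0)


def projected_gradient_unrolled(p_init, step=0.25, num_steps=50):
    # A fixed number of projected gradient steps, written as a plain
    # Python loop so that autodiff differentiates through each one.
    p = p_init
    for _ in range(num_steps):
        p = project_box(p - step * jax.grad(objective)(p))
    return p


# Jit the whole Jacobian computation so the unrolled loop is compiled once.
jac_fn = jax.jit(jax.jacobian(projected_gradient_unrolled))
jac = jac_fn(jnp.ones(12))
max_abs_entry = jnp.max(jnp.abs(jac))
print(jac.shape, max_abs_entry)
```

In this toy setup each step shrinks the iterate by a constant factor, so the Jacobian with respect to the initialization is tiny but need not be exactly zero, which matches the behaviour described above.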

I suggest you take a step back and think about the meaning of the derivative you want to compute. For example, on non-convex problems the initialization matters because different initializations yield different (local) optima. But even in that case, the function that maps p_0 to p_t (the optimum) is usually piecewise constant (each piece corresponding to a different basin of attraction), so its derivative is zero almost everywhere.
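
A small sketch of the piecewise-constant behaviour, assuming a one-dimensional double-well objective and plain gradient descent (the names double_well and gradient_descent are made up for this example and are not part of JAXopt):

```python
import jax
import jax.numpy as jnp


def double_well(p):
    # Non-convex toy objective with local minima at p = -1 and p = +1.
    return (p ** 2 - 1.0) ** 2


def gradient_descent(p_init, step=0.1, num_steps=200):
    # Unrolled gradient descent starting from p_init.
    p = p_init
    for _ in range(num_steps):
        p = p - step * jax.grad(double_well)(p)
    return p


sol_right = gradient_descent(0.5)   # starts in the basin of +1
sol_left = gradient_descent(-0.5)   # starts in the basin of -1

# Sensitivity of the returned optimum to the initialization.
sensitivity = jax.grad(gradient_descent)(0.5)
print(sol_right, sol_left, sensitivity)
```

The two initializations land in different optima, yet within one basin the returned optimum barely reacts to the starting point, so the derivative with respect to the initialization carries almost no information.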

Algue-Rythme, Nov 11 '22 17:11