jaxopt
Assertion error when trying to take jacobian of projected gradient solution using projection_polyhedron
The output of compute_pg is a (12, 1) 2D array. This is a minimal reproducible example.
import jax.numpy as jnp
import numpy as np
from functools import partial
import jax
from jax import jit
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_polyhedron

num_obs = 6
C_obs_1 = 1 * jnp.identity(num_obs)
C_obs_2 = -1 * jnp.identity(num_obs)
C_obs = jnp.block([
    [C_obs_1, 0 * jnp.identity(num_obs)],
    [C_obs_2, 0 * jnp.identity(num_obs)],
    [0 * jnp.identity(num_obs), C_obs_1],
    [0 * jnp.identity(num_obs), C_obs_2]
])
A_obs = jnp.zeros((1, jnp.shape(C_obs)[1]))
a_obstacle = jnp.zeros((1, 1))

def compute_obstacle_penalty_temp(p):
    cost_obs_penalty = 1.0 * jnp.linalg.norm(p)**2
    return cost_obs_penalty

def proj(p, C):
    return projection_polyhedron(p, C, check_feasible=False)

def compute_pg(p):
    p = jnp.reshape(p, (jnp.shape(p)[0], 1))
    b_obs = jnp.ones((jnp.shape(C_obs)[0], 1))
    pg = ProjectedGradient(fun=compute_obstacle_penalty_temp, projection=proj, jit=True)
    pg_sol = pg.run(p, hyperparams_proj=(A_obs, a_obstacle, C_obs, b_obs)).params
    return pg_sol

def compute_bilevel():
    return jax.jacobian(compute_pg)(jnp.ones((12, 1)))

compute_bilevel()
This is the error I get:
File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 41, in <module>
compute_bilevel()
File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 39, in compute_bilevel
return jax.jacobian(compute_pg)(jnp.ones((12,1)))
File "/home/ims/.local/lib/python3.8/site-packages/jax/_src/api.py", line 1362, in jacfun
jac = vmap(pullback)(_std_basis(y))
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
u = solve(matvec, v)
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
Ab = rmatvec(b) # A.T b
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 145, in <lambda>
return lambda y: transpose(y)[0]
AssertionError
What I observed is that if I set implicit_diff=False in ProjectedGradient, it works but is very slow. Kindly advise.
Hi Basant1861
When you differentiate a Jaxopt solver, it will attempt, whenever possible, to use implicit differentiation. Implicit differentiation is only possible if the argument you are differentiating with respect to appears in the optimality conditions of your problem.
That's your issue: you are trying to differentiate with respect to p (i.e. the initialization of ProjectedGradient), but it does not appear in the optimality conditions; for a convex problem the initialization plays no role in the solution. Hence, mathematically, the derivative should be zero anyway. Numerically, it's trickier.
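For contrast, here is a minimal sketch (not from your code, and untested) of a case where implicit differentiation does apply: differentiating the solution with respect to the right-hand side h of the inequality constraints G x <= h passed to projection_polyhedron, since h enters the optimality conditions through the projection.

import jax
import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_polyhedron

dim = 4
A_eq = jnp.zeros((1, dim))                      # trivial equality constraints A_eq x = b_eq
b_eq = jnp.zeros(1)
G = jnp.vstack([jnp.eye(dim), -jnp.eye(dim)])   # box constraints written as G x <= h

def objective(x):
    # Unconstrained minimum at x = 2, so the constraints are active at the solution.
    return jnp.sum((x - 2.0) ** 2)

def proj(x, hyperparams):
    return projection_polyhedron(x, hyperparams, check_feasible=False)

def solve(h):
    pg = ProjectedGradient(fun=objective, projection=proj, jit=True)
    x0 = jnp.zeros(dim)
    return pg.run(x0, hyperparams_proj=(A_eq, b_eq, G, h)).params

h0 = jnp.ones(2 * dim)                           # |x_i| <= 1
# Implicit differentiation is used here because h appears in the optimality
# conditions; the Jacobian d x*(h) / d h is well defined.
jac_wrt_h = jax.jacobian(solve)(h0)

Note that only h0 is treated as differentiable; the initialization x0 is a constant, which is exactly the distinction above.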
Indeed, with implicit differentiation it does not work, because Jaxopt cannot handle differentiating with respect to variables that are not part of the optimality conditions. Without implicit differentiation, unrolling can return a value. This value should be zero in an ideal world, but with numerical errors I cannot guarantee it (I just tried and it is around 1e-9). If you want to speed it up, you can try wrapping your whole compute_pg function in jax.jit and disabling implicit diff.
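Concretely, that workaround could look like the sketch below (reusing compute_obstacle_penalty_temp, proj, C_obs, A_obs and a_obstacle from your snippet; I have not benchmarked it):

# Disable implicit diff so differentiation falls back to unrolling the solver
# iterations, and jit the whole function so the unrolled loop is compiled once.
pg_unrolled = ProjectedGradient(fun=compute_obstacle_penalty_temp,
                                projection=proj,
                                implicit_diff=False,
                                jit=True)

@jax.jit
def compute_pg_unrolled(p):
    p = jnp.reshape(p, (jnp.shape(p)[0], 1))
    b_obs = jnp.ones((jnp.shape(C_obs)[0], 1))
    return pg_unrolled.run(p, hyperparams_proj=(A_obs, a_obstacle, C_obs, b_obs)).params

jac = jax.jacobian(compute_pg_unrolled)(jnp.ones((12, 1)))
# Expect values on the order of 1e-9: mathematically this Jacobian is zero.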
I suggest you take a step back and think about the meaning of the derivative you want to compute. For example, on non-convex problems the initialization does matter, because different initializations will yield different (local) optima. But in that case the function that maps the initialization p_0 to the optimum p_t is usually piecewise constant (each piece corresponding to a different basin of attraction), so its derivative is still zero almost everywhere.