jaxopt
jaxopt copied to clipboard
Document that the `lower` and `upper` bounds in `Bisection` need `stop_gradient` for differentiation
Hello, I'm trying to work with the following snippet,
import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection
@jax.jit
def _xy_c(r, phi, spin, theta_o):
    """Return the image-plane coordinates ``(alpha, beta)`` for radius ``r``.

    Args:
      r: radius; must differ from 1 (the formulas divide by ``r - 1``).
      phi: unused here; accepted so the signature matches ``_r_c_solve``.
      spin: spin parameter; must be non-zero (``r`` and ``r**3`` are
        divided by it).
      theta_o: angle in radians (callers in this snippet convert from
        degrees); ``sin`` and ``tan`` of it must be non-zero.

    Returns:
      ``(alpha, beta)``, where ``beta`` is the signed square root
      ``sign(b) * sqrt(|b|)`` so it stays real when ``b`` is negative.

    NOTE(review): ``lam``/``eta`` look like Kerr spherical-photon-orbit
    constants -- not confirmed by this snippet.
    """
    # Subexpression shared by lam and eta (computed once, same arithmetic).
    delta = r**2 - 2 * r + spin**2
    lam = spin + r / spin * (r - (2 * delta) / (r - 1))
    eta = r**3 / spin**2 * ((4 * delta) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    # Signed square root: keep the sign, take the root of the magnitude.
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta
@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    """Residual in ``r`` used as Bisection's ``optimality_fun``.

    Returns the polar angle of ``(alpha, beta) = _xy_c(r, ...)`` in degrees,
    wrapped into ``[-90, 270)``, minus ``phi`` converted from radians to
    degrees. The residual is zero where that angle equals ``phi``.
    """
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    # arctan2 in degrees lies in [-180, 180]; "+ 90, % 360, - 90" re-wraps it
    # into [-90, 270) before subtracting phi (radians -> degrees).
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi
def r_c_solve(phi, spin, theta_o):
    """Solve ``_r_c_solve(r, phi, spin, theta_o) == 0`` for ``r`` with Bisection.

    Args:
      phi: angle in degrees; converted to radians before solving.
      spin: forwarded to the residual and also used to build the bracket
        ``[r_0, r_1]`` passed as ``lower``/``upper``.
      theta_o: angle in degrees; converted to radians and clipped to
        ``[1e-5, pi - 1e-5]``.

    Returns:
      The ``params`` field of the Bisection result (the solution for ``r``).

    NOTE(review): this is the reproduction from the report and is kept
    as-is. ``r_0``/``r_1`` are computed with ``jnp`` from ``spin`` and handed
    to ``Bisection`` as ``lower``/``upper``. The reported traceback for
    ``jax.grad(jax.jit(r_c_solve))`` says the leaked intermediate value was
    created inside this function.
    """
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    # Presumably keeps sin/tan of theta_o (divided by in _xy_c) away from zero.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    # Pad the bracket by 1e-4 * (r_p - r_m) at each end.
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    # phi/spin/theta_o are presumably forwarded to optimality_fun as keyword
    # arguments; check_bracket=False presumably skips the bracket sign check.
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
I think it usually does not matter whether one jits the intermediate functions, i.e. `jit(A(B))` is the same as `jit(A(jit(B)))`. However, I find this is no longer the case when `jaxopt.Bisection` is involved. For example, the following `g_r_c_solve_0` and `g_r_c_solve_1` work well:
# Reported to work: grad of the un-jitted r_c_solve.
g_r_c_solve_0 = jax.grad(r_c_solve)
%time g_r_c_solve_0(10., 0.9375, 163)
# Reported to work: jit applied outside grad only.
g_r_c_solve_1 = jax.jit(jax.grad(r_c_solve))
%time g_r_c_solve_1(10., 0.9375, 163)
But `g_r_c_solve_2` and `g_r_c_solve_3` give me an `UnexpectedTracerError`:
# Reported to raise UnexpectedTracerError: grad of an already-jitted r_c_solve.
g_r_c_solve_2 = jax.grad(jax.jit(r_c_solve))
%time g_r_c_solve_2(10., 0.9375, 163)
# Reported to raise the same error: jit(grad(jit(r_c_solve))).
g_r_c_solve_3 = jax.jit(jax.grad(jax.jit(r_c_solve)))
%time g_r_c_solve_3(10., 0.9375, 163)
The full error message is below:
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel_launcher.py:17
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/traitlets/config/application.py:1043, in launch_instance()
1042 app.initialize(argv)
-> 1043 app.start()
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelapp.py:728, in start()
727 try:
--> 728 self.io_loop.start()
729 except KeyboardInterrupt:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/tornado/platform/asyncio.py:195, in start()
194 def start(self) -> None:
--> 195 self.asyncio_loop.run_forever()
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/base_events.py:607, in run_forever()
606 while True:
--> 607 self._run_once()
608 if self._stopping:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/base_events.py:1922, in _run_once()
1921 else:
-> 1922 handle._run()
1923 handle = None
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:516, in dispatch_queue()
515 try:
--> 516 await self.process_one()
517 except Exception:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:505, in process_one()
504 return None
--> 505 await dispatch(*args)
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:412, in dispatch_shell()
411 if inspect.isawaitable(result):
--> 412 await result
413 except Exception:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:740, in execute_request()
739 if inspect.isawaitable(reply_content):
--> 740 reply_content = await reply_content
742 # Flush output before sending the reply.
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/ipkernel.py:422, in do_execute()
421 if with_cell_id:
--> 422 res = shell.run_cell(
423 code,
424 store_history=store_history,
425 silent=silent,
426 cell_id=cell_id,
427 )
428 else:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/zmqshell.py:540, in run_cell()
539 self._last_traceback = None
--> 540 return super().run_cell(*args, **kwargs)
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3009, in run_cell()
3008 try:
-> 3009 result = self._run_cell(
3010 raw_cell, store_history, silent, shell_futures, cell_id
3011 )
3012 finally:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3064, in _run_cell()
3063 try:
-> 3064 result = runner(coro)
3065 except BaseException as e:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3269, in run_cell_async()
3266 interactivity = "none" if silent else self.ast_node_interactivity
-> 3269 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3270 interactivity=interactivity, compiler=compiler, result=result)
3272 self.last_execution_succeeded = not has_raised
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3448, in run_ast_nodes()
3447 asy = compare(code)
-> 3448 if await self.run_code(code, result, async_=asy):
3449 return True
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3508, in run_code()
3507 else:
-> 3508 exec(code_obj, self.user_global_ns, self.user_ns)
3509 finally:
3510 # Reset our crash handler in place
Cell In[10], line 1
----> 1 get_ipython().run_line_magic('time', 'g_r_c_solve_2(10., 0.9375, 163)')
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:2417, in run_line_magic()
2416 with self.builtin_trap:
-> 2417 result = fn(*args, **kwargs)
2419 # The code below prevents the output from being displayed
2420 # when using magics with decodator @output_can_be_silenced
2421 # when the last Python token in the expression is a ';'.
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/magics/execution.py:1317, in time()
1316 try:
-> 1317 out = eval(code, glob, local_ns)
1318 except:
File <timed eval>:1
Cell In[2], line 24, in r_c_solve()
22 r_1 = r_p + 0.0001 * (r_p - r_m)
23 return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
---> 24 check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/bisection.py:158, in run()
153 def run(self,
154 init_params: Optional[Any] = None,
155 *args,
156 **kwargs) -> base.OptStep:
157 # We override run in order to set init_params=None by default.
--> 158 return super().run(init_params, *args, **kwargs)
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/base.py:354, in run()
352 run = decorator(run)
--> 354 return run(init_params, *args, **kwargs)
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:251, in wrapped_solver_fun()
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
JaxStackTraceBeforeTransformation: jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was r_c_solve at /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:15 traced for jit.
------------------------------
The leaked intermediate value was created on line /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:198:11 (_run_module_as_main)
<frozen runpy>:88:4 (_run_code)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/1728866946.py:1 (<module>)
<timed eval>:1 (<module>)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
UnexpectedTracerError Traceback (most recent call last)
File <timed eval>:1
[... skipping hidden 33 frame]
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:210, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_fwd(*flat_args)
209 def solver_fun_fwd(*flat_args):
--> 210 res = solver_fun_flat(*flat_args)
211 return res, (res, flat_args)
[... skipping hidden 4 frame]
File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1835, in DynamicJaxprTrace.getvar(self, tracer)
1833 var = self.frame.tracer_to_var.get(id(tracer))
1834 if var is None:
-> 1835 raise core.escaped_tracer_error(tracer)
1836 return var
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was r_c_solve at /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:15 traced for jit.
------------------------------
The leaked intermediate value was created on line /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:198:11 (_run_module_as_main)
<frozen runpy>:88:4 (_run_code)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/1728866946.py:1 (<module>)
<timed eval>:1 (<module>)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
It turns out that I cannot take gradients of already jitted functions. Is it possible to fix this issue?
FYI, I'm using `jax=0.4.13`, `jaxlib=0.4.13`, `jaxopt=0.7` and `python=3.11.4`.
Hello @hjia,
Thanks for reporting this! The issue is that `lower` and `upper` need to be static arguments to make the implicit differentiation compatible with `jit` (in the `solve_2` and `solve_3` cases). A possible workaround is presented below. We could also change the signature of the `Bisection` method to let `lower` and `upper` be taken as parameters. But before changing this, could you give us a bit more context on what you want to do? A priori, `grad` already traces the function, so you would not need to `jit` it before taking the gradient.
import numpy as np
import jax.numpy as jnp
import numpy as np
import jax
from jaxopt import Bisection
from functools import partial
@jax.jit
def _xy_c(r, phi, spin, theta_o):
    """Return the image-plane coordinates ``(alpha, beta)`` for radius ``r``.

    Args:
      r: radius; must differ from 1 (the formulas divide by ``r - 1``).
      phi: unused here; accepted so the signature matches ``_r_c_solve``.
      spin: spin parameter; must be non-zero (``r`` and ``r**3`` are
        divided by it).
      theta_o: angle in radians (callers in this snippet convert from
        degrees); ``sin`` and ``tan`` of it must be non-zero.

    Returns:
      ``(alpha, beta)``, where ``beta`` is the signed square root
      ``sign(b) * sqrt(|b|)`` so it stays real when ``b`` is negative.

    NOTE(review): ``lam``/``eta`` look like Kerr spherical-photon-orbit
    constants -- not confirmed by this snippet.
    """
    # Subexpression shared by lam and eta (computed once, same arithmetic).
    delta = r**2 - 2 * r + spin**2
    lam = spin + r / spin * (r - (2 * delta) / (r - 1))
    eta = r**3 / spin**2 * ((4 * delta) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    # Signed square root: keep the sign, take the root of the magnitude.
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta
@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    """Residual in ``r`` used as Bisection's ``optimality_fun``.

    Returns the polar angle of ``(alpha, beta) = _xy_c(r, ...)`` in degrees,
    wrapped into ``[-90, 270)``, minus ``phi`` converted from radians to
    degrees. The residual is zero where that angle equals ``phi``.
    """
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    # arctan2 in degrees lies in [-180, 180]; "+ 90, % 360, - 90" re-wraps it
    # into [-90, 270) before subtracting phi (radians -> degrees).
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi
def r_c_solve(phi, spin, theta_o):
    """Workaround version of ``r_c_solve``: same solve, bracket built with NumPy.

    Identical to the reported ``r_c_solve`` except that ``r_m``/``r_p`` use
    ``np.cos``/``np.arccos`` instead of ``jnp``. As a result, ``lower``/``upper``
    are computed by NumPy rather than traced by JAX. NumPy needs a concrete
    ``spin``, which is why the jitted calls below pass
    ``static_argnames='spin'``.

    Args:
      phi: angle in degrees; converted to radians before solving.
      spin: concrete (non-traced) value; used for the bracket and forwarded
        to the residual.
      theta_o: angle in degrees; converted to radians and clipped to
        ``[1e-5, pi - 1e-5]``.

    Returns:
      The ``params`` field of the Bisection result (the solution for ``r``).
    """
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    # Presumably keeps sin/tan of theta_o (divided by in _xy_c) away from zero.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    # np (not jnp): the bracket must not be a JAX tracer.
    r_m = 2 * (1 + np.cos(2 / 3 * np.arccos(-spin)))
    r_p = 2 * (1 + np.cos(2 / 3 * np.arccos(spin)))
    # Pad the bracket by 1e-4 * (r_p - r_m) at each end.
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
# Plain grad: spin arrives as a concrete Python float, so NumPy can use it.
g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(10., spin=0.9375, theta_o=163)
# Wherever jit traces r_c_solve, spin must be static because the workaround
# builds the Bisection bounds from it with NumPy.
g_r_c_solve_1 = jax.jit(jax.grad(r_c_solve), static_argnames='spin')
g_r_c_solve_1(10., spin=0.9375, theta_o=163)
g_r_c_solve_2 = jax.grad(jax.jit(r_c_solve, static_argnames='spin'))
g_r_c_solve_2(10., spin=0.9375, theta_o=163)
# The outer jit must mark spin static too; otherwise it would pass a tracer
# into the inner jit's static `spin` argument.
g_r_c_solve_3 = jax.jit(jax.grad(jax.jit(r_c_solve, static_argnames='spin')),
                        static_argnames='spin')
# Call g_r_c_solve_3 itself (the original called g_r_c_solve_2 again here).
g_r_c_solve_3(10., spin=0.9375, theta_o=163)
We might want to make the `Bisection` method differentiable with respect to its `lower` and `upper` values. For example, the following code fails. So it would be nice to have a use case to help us rethink the implementation of `Bisection`.
import numpy as np
import jax.numpy as jnp
import numpy as np
import jax
from jaxopt import Bisection
from functools import partial
@jax.jit
def _xy_c(r, phi, spin, theta_o):
    """Return the image-plane coordinates ``(alpha, beta)`` for radius ``r``.

    Args:
      r: radius; must differ from 1 (the formulas divide by ``r - 1``).
      phi: unused here; accepted so the signature matches ``_r_c_solve``.
      spin: spin parameter; must be non-zero (``r`` and ``r**3`` are
        divided by it).
      theta_o: angle in radians (callers in this snippet convert from
        degrees); ``sin`` and ``tan`` of it must be non-zero.

    Returns:
      ``(alpha, beta)``, where ``beta`` is the signed square root
      ``sign(b) * sqrt(|b|)`` so it stays real when ``b`` is negative.

    NOTE(review): ``lam``/``eta`` look like Kerr spherical-photon-orbit
    constants -- not confirmed by this snippet.
    """
    # Subexpression shared by lam and eta (computed once, same arithmetic).
    delta = r**2 - 2 * r + spin**2
    lam = spin + r / spin * (r - (2 * delta) / (r - 1))
    eta = r**3 / spin**2 * ((4 * delta) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    # Signed square root: keep the sign, take the root of the magnitude.
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta
@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    """Residual in ``r`` used as Bisection's ``optimality_fun``.

    Returns the polar angle of ``(alpha, beta) = _xy_c(r, ...)`` in degrees,
    wrapped into ``[-90, 270)``, minus ``phi`` converted from radians to
    degrees. The residual is zero where that angle equals ``phi``.
    """
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    # arctan2 in degrees lies in [-180, 180]; "+ 90, % 360, - 90" re-wraps it
    # into [-90, 270) before subtracting phi (radians -> degrees).
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi
def r_c_solve(spin, phi, theta_o):
    """Variant of ``r_c_solve`` that takes ``spin`` as its first argument.

    ``jax.grad`` differentiates with respect to the first argument by
    default, so ``jax.grad(r_c_solve)`` differentiates with respect to
    ``spin``. ``spin`` also determines the Bisection bounds ``r_0``/``r_1``,
    which are computed with ``jnp``. The thread reports that this fails:
    the bounds are arguments of the algorithm, not of the optimality
    function, so implicit differentiation does not cover them. Kept as the
    failing example.

    Args:
      spin: spin parameter; used for the bracket and forwarded to the residual.
      phi: angle in degrees; converted to radians before solving.
      theta_o: angle in degrees; converted to radians and clipped to
        ``[1e-5, pi - 1e-5]``.

    Returns:
      The ``params`` field of the Bisection result (the solution for ``r``).
    """
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    # Presumably keeps sin/tan of theta_o (divided by in _xy_c) away from zero.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    # Pad the bracket by 1e-4 * (r_p - r_m) at each end.
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
# Differentiates with respect to spin (now the first argument), which also
# sets the Bisection bounds; the thread reports that this call fails.
g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(0.9375, 10., 63)
We could also change the signature of the Bisection method to let lower and upper be taken as parameters
I don't think we can. `lower` and `upper` are arguments of the algorithm, not of the objective, which means they're not part of the optimality conditions. So we can't use implicit differentiation. Unrolling will likely not work either, due to discontinuous operations.
Not sure if it's applicable here, but an alternative would be to use `stop_gradient` (see example here).
If you agree with me, we can relabel this issue as documentation. Adding a short paragraph on this would be helpful.
Yes, I see the issue. This would be good to know. Thanks!
Sorry for the delayed reply. In the example above, making `spin` static in `jit` is not a good idea for me, since I do need this to work for many different values of `spin`.
The issue here does not really prevent me from computing what I want, but it does make my code ugly: I need a jitted version and an unjitted version of each function, rather than just jitting everything at definition.
I'm not really an expert on `jax.jit`, but technically, is it possible to get some pointer to `foo` from `jax.jit(foo)`? If so, then I think there should be a way to make `jax.jit(jax.grad(jax.jit(r_c_solve)))` work like `jax.jit(jax.grad(r_c_solve))`, i.e. just let it use the underlying unjitted function instead of the jitted one.
@mblondel I'm not sure I understand your comment regarding `stop_gradient`. Nothing changes if I use `lower=jax.lax.stop_gradient(r_0), upper=jax.lax.stop_gradient(r_1)` in my snippet.