diffrax
Error when using a TPU on Google Colab
I am trying to use Diffrax on a Colab TPU. The error below appears only with the TPU; the same code runs fine on CPU and GPU.
```python
import jax.tools.colab_tpu
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

jax.tools.colab_tpu.setup_tpu()

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)

print(sol.ts)
print(sol.ys)
```
The following error is thrown:
```
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
/usr/lib/python3.7/runpy.py in _run_module_as_main(***failed resolving arguments***)
    192     return _run_code(code, main_globals, None,
--> 193                      "__main__", mod_spec)
    194 

59 frames

JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

UnfilteredStackTrace                      Traceback (most recent call last)
UnfilteredStackTrace: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
   1024                               flat_results_aval=flat_results_aval,
   1025                               **params))
-> 1026   next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
   1027                                                           callback_id,
   1028                                                           args_to_outfeed)

AttributeError: 'NoneType' object has no attribute 'add_outfeed'
```
I'm not sure whether this is a Colab-TPU-specific problem or a JAX problem, but other people have reported a similar issue with JAX: https://github.com/google/jax/issues/9053
Here's my Colab: https://colab.research.google.com/drive/1ZMUrgbelBwngHwiD-a6eeEWhdwRZJsBh?usp=sharing
Anyway, the work you are doing is amazing.
Hmm. This looks like a JAX problem rather than a Diffrax one: specifically, host callbacks and TPUs aren't interacting properly for some users.
You can probably work around this by disabling Diffrax's use of host callbacks. Try putting this at the top of your code, before you import anything else:
```python
import diffrax.misc.errors
diffrax.misc.errors.branched_error_if = lambda *a, **kw: None
```
This monkey-patches that function.
Diffrax uses host callbacks for only one thing: raising errors on malformed inputs, e.g. when `t1 > t0` but `dt0 < 0`. The monkey-patch above disables the use of host callbacks, but it also means that such errors may silently bite you instead.
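To make that trade-off concrete, here is a minimal plain-Python sketch that uses neither Diffrax nor JAX. The names `toy_errors`, `_branched_error_if`, and `solve` are hypothetical stand-ins, and the real Diffrax check is built on host callbacks rather than a plain `raise`. The toy check flags a sign mismatch between `t1 - t0` and `dt0`, which covers the `t1 > t0` but `dt0 < 0` example.

```python
import types

# Hypothetical stand-in for a module like diffrax.misc.errors.
toy_errors = types.ModuleType("toy_errors")


def _branched_error_if(pred, msg):
    # Toy check: raise only when the predicate is true.
    if pred:
        raise ValueError(msg)


toy_errors.branched_error_if = _branched_error_if


def solve(t0, t1, dt0):
    # Hypothetical solver entry point. It looks the check up through the
    # module attribute at call time, so patching the module affects it.
    toy_errors.branched_error_if(
        (t1 - t0) * dt0 < 0, "dt0 must have the same sign as t1 - t0"
    )
    return "solved"


# Malformed input (t1 > t0 but dt0 < 0) is rejected...
try:
    solve(0.0, 3.0, -0.1)
    rejected_before = False
except ValueError:
    rejected_before = True

# ...until the check is monkey-patched into a no-op, after which the
# malformed input is silently accepted.
toy_errors.branched_error_if = lambda *a, **kw: None
result_after = solve(0.0, 3.0, -0.1)

print(rejected_before, result_after)  # prints: True solved
```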
I recently hit a similar issue. Here is a correction to my previous comment, with a better fix:
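```python
import sys

import diffrax  # import first, so that Diffrax's submodules are in sys.modules

# Replace branched_error_if with a no-op in every loaded Diffrax module.
for module_name, module in list(sys.modules.items()):
    if module_name.startswith("diffrax"):
        if hasattr(module, "branched_error_if"):
            module.branched_error_if = lambda *a, **kw: None
```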
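Why loop over modules instead of doing a single assignment? A plausible reason is Python's import semantics: `from module import name` stores a reference in the importing module, so rebinding the name on the defining module afterwards does not affect modules that already imported it. The sketch below illustrates that general mechanism with hypothetical toy modules (`toypkg.errors`, `toypkg.solver`). It does not describe Diffrax's actual module layout.

```python
import sys
import types

# Hypothetical package: toypkg.errors defines the check.
errors = types.ModuleType("toypkg.errors")


def branched_error_if(pred, msg):
    if pred:
        raise ValueError(msg)


errors.branched_error_if = branched_error_if

# Simulates a module that ran `from toypkg.errors import branched_error_if`:
# it holds its own reference to the original function object.
solver = types.ModuleType("toypkg.solver")
solver.branched_error_if = errors.branched_error_if


def solve(t0, t1, dt0):
    # Looks the name up in the solver module's own namespace.
    solver.branched_error_if((t1 - t0) * dt0 < 0, "bad dt0")
    return "solved"


solver.solve = solve
sys.modules["toypkg.errors"] = errors
sys.modules["toypkg.solver"] = solver


def rejects(t0, t1, dt0):
    try:
        solver.solve(t0, t1, dt0)
    except ValueError:
        return True
    return False


# Patching only the defining module: the solver still holds the original.
errors.branched_error_if = lambda *a, **kw: None
after_single_patch = rejects(0.0, 3.0, -0.1)

# Patching every loaded module under the package prefix reaches the solver too.
for module_name, module in list(sys.modules.items()):
    if module_name.startswith("toypkg"):
        if hasattr(module, "branched_error_if"):
            module.branched_error_if = lambda *a, **kw: None
after_loop_patch = rejects(0.0, 3.0, -0.1)

print(after_single_patch, after_loop_patch)  # prints: True False
```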
That works for me, thanks @patrick-kidger!
Closing, as I think this should be resolved now. We no longer use the old `host_callback` mechanism, and `equinox.internal.branched_error_if` should now work robustly on all backends.