
Error When I use TPU on Google Colab.

Open tentechit opened this issue 1 year ago • 3 comments

I'm trying to use a TPU on Colab with Diffrax. An error appears when I use the TPU, but there is no problem on CPU or GPU.

import jax.tools.colab_tpu
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

jax.tools.colab_tpu.setup_tpu()

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)

print(sol.ts)
print(sol.ys)

This throws the following error:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
[/usr/lib/python3.7/runpy.py](https://localhost:8080/#) in _run_module_as_main(***failed resolving arguments***)
    192     return _run_code(code, main_globals, None,
--> 193                      "__main__", mod_spec)
    194 

59 frames
JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

UnfilteredStackTrace                      Traceback (most recent call last)
UnfilteredStackTrace: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
   1024             flat_results_aval=flat_results_aval,
   1025             **params))
-> 1026     next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
   1027                                                         callback_id,
   1028                                                         args_to_outfeed)

AttributeError: 'NoneType' object has no attribute 'add_outfeed'

I'm not sure whether this is a Colab TPU-specific problem or a JAX one, but I found that some people have hit a similar issue with JAX: https://github.com/google/jax/issues/9053

Here's my Colab: https://colab.research.google.com/drive/1ZMUrgbelBwngHwiD-a6eeEWhdwRZJsBh?usp=sharing

Anyway, the work you are doing is amazing.

tentechit, Jul 15 '22 13:07

Hmm. This looks like a JAX problem rather than a Diffrax one: host callbacks and TPUs aren't interacting properly for some users.

You can probably work around this by disabling Diffrax's use of host callbacks. Try putting this at the top of your code, before you import anything else:

import diffrax.misc.errors
diffrax.misc.errors.branched_error_if = lambda *a, **kw: None

This monkey-patches the function, replacing it with a no-op that always returns None.

Diffrax only uses host callback for one thing -- raising errors on malformed inputs, e.g. when t1 > t0 but dt0 < 0. The monkey-patching above will disable the use of host callback, but will of course also mean that errors may silently bite you instead.
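Since the patch silences those checks, one option is to validate the time arguments yourself, in plain Python, before calling diffeqsolve. The following is only a minimal sketch: `check_time_args` is a hypothetical helper, not part of Diffrax, and it assumes `t0`, `t1` and `dt0` are concrete Python numbers rather than traced JAX values. The first branch is the case named above; the second is the mirror-image case, added here as an assumption.

```python
def check_time_args(t0, t1, dt0):
    """Hypothetical pre-flight check; not part of Diffrax.

    Assumes t0, t1 and dt0 are concrete Python numbers, not traced values.
    """
    # The case mentioned above: integrating forwards with a negative step.
    if t1 > t0 and dt0 < 0:
        raise ValueError(f"t1 > t0 but dt0 = {dt0} is negative")
    # Assumed mirror case: integrating backwards with a positive step.
    if t1 < t0 and dt0 > 0:
        raise ValueError(f"t1 < t0 but dt0 = {dt0} is positive")


# Same values as in the original report; this passes silently.
check_time_args(t0=0, t1=3, dt0=0.1)
```

Calling it just before diffeqsolve gives back at least this one error as an ordinary Python exception, without going through host callbacks.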

patrick-kidger, Jul 15 '22 13:07

I recently hit a similar issue. Just correcting my previous comment to give a better fix:

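import sys
# Run this after `import diffrax`, so the Diffrax modules are already in sys.modules.
# Iterate over a copy, in case sys.modules changes during the loop.
for module_name, module in list(sys.modules.items()):
  if module_name.startswith("diffrax"):
    if hasattr(module, "branched_error_if"):
      # Replace every bound copy of the function with a no-op.
      module.branched_error_if = lambda *a, **kw: None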
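A likely reason the loop over sys.modules is more reliable than patching a single module: a `from ... import name` statement binds its own reference in the importing module, so rebinding the name where it is defined does not reach modules that have already imported it. The sketch below shows this with two stand-in modules built from the standard library only; `fakelib.errors`, `fakelib.solver` and `check` are hypothetical names and say nothing about Diffrax's real layout.

```python
import sys
import types


def _raise(*args, **kwargs):
    raise ValueError("bad input")


# Stand-in for the module that defines the error helper.
errors = types.ModuleType("fakelib.errors")
errors.check = _raise

# Stand-in for a module that did `from fakelib.errors import check`.
solver = types.ModuleType("fakelib.solver")
solver.check = errors.check

sys.modules["fakelib.errors"] = errors
sys.modules["fakelib.solver"] = solver

# Patching only the defining module leaves solver's reference untouched.
errors.check = lambda *a, **kw: None
try:
    solver.check("anything")
    still_raises = False
except ValueError:
    still_raises = True

# Walking sys.modules rebinds every copy of the name.
for name, module in list(sys.modules.items()):
    if name.startswith("fakelib") and hasattr(module, "check"):
        module.check = lambda *a, **kw: None

# Tidy up the stand-in entries.
del sys.modules["fakelib.errors"], sys.modules["fakelib.solver"]
```

After the first patch `solver.check` still raises; only after the loop does it become a no-op.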

patrick-kidger, Jul 22 '22 16:07

That works for me, thanks @patrick-kidger!

tentechit, Jul 23 '22 09:07

Closing as I think this should be resolved now. We no longer use the old host_callback mechanism, and equinox.internal.branched_error_if should now work robustly on all backends.

patrick-kidger, May 25 '23 18:05