`jax.pure_callback` crashes on TPU VM
Description
This is a follow-up to the discussion at https://github.com/google/jax/discussions/12245. The solution suggested there by a project collaborator crashes the Python interpreter on my TPU VM, although it should work on every platform.
Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> import scipy.linalg
>>>
>>> def schur(x):
...     return jax.pure_callback(scipy.linalg.schur, (x, x), x)
...
>>> @jax.jit
... def f(x):
...     return schur(x)
...
>>> print(f(jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)))
F0907 14:29:33.631891 939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/api.py", line 528, in cache_miss
out_flat = xla.xla_call(
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1963, in bind
return call_bind(self, fun, *args, **params)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1979, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 689, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl
return compiled_fun(*args)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 841, in _execute_compiled
out_bufs = token_handler(out_bufs, runtime_token)
File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 808, in _remove_tokens
output_token_buf, *token_bufs = token_bufs
jax._src.traceback_util.UnfilteredStackTrace: ValueError: not enough values to unpack (expected at least 1, got 0)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: not enough values to unpack (expected at least 1, got 0)
>>> *** Check failure stack trace: ***
@ 0x7f07a33b98e4 (unknown)
@ 0x7f07a33b93ca (unknown)
@ 0x7f07a33b9c49 (unknown)
@ 0x7f079ff89e79 (unknown)
@ 0x7f079ff70f66 (unknown)
@ 0x7f079dfa17c3 (unknown)
@ 0x7f07a3297669 (unknown)
@ 0x7f07a32930fe (unknown)
@ 0x7f0e8753b609 start_thread
https://symbolize.stripped_domain/r/?trace=7f07a33b98e4,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
https://symbolize.stripped_domain/r/?trace=7f0e8759900b,7f0e8759908f,7f07a33b9943,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
*** SIGABRT received by PID 937881 (TID 939012) on cpu 81 from PID 937881; ***
E0907 14:29:33.636138 939012 coredump_hook.cc:370] RAW: Remote crash data gathering hook invoked.
E0907 14:29:33.636153 939012 coredump_hook.cc:416] RAW: Skipping coredump since rlimit was 0 at process start.
E0907 14:29:33.636161 939012 client.cc:242] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0907 14:29:33.636166 939012 coredump_hook.cc:477] RAW: Sending fingerprint to remote end.
E0907 14:29:33.636174 939012 coredump_socket.cc:118] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0907 14:29:33.636184 939012 coredump_hook.cc:481] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0907 14:29:33.636190 939012 coredump_hook.cc:555] RAW: Discarding core.
F0907 14:29:33.631891 939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
E0907 14:29:33.898010 939012 process_state.cc:774] RAW: Raising signal 6 with default behavior
fish: “pipenv run python3” terminated by signal SIGABRT (Abort)
What jax/jaxlib version are you using?
jax v0.3.17, jaxlib v0.3.15
Which accelerator(s) are you using?
TPU v3-8 with libtpu v1.3.0
Additional System Info
Python 3.8.10, TPU VM on GCP running Ubuntu 20.04 (Linux t1v-n-e307e167-w-0 5.13.0-1023-gcp #28~20.04.1-Ubuntu SMP Wed Mar 30 03:51:07 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux)
It seems callbacks do not currently work on Cloud TPU VMs because they still use the older stream_executor runtime. They will soon switch to a newer runtime that does support callbacks. I'll monitor this and update the issue when callbacks work.
@sharadmv has there been any movement on this? Is there perhaps an experimental VM image we can use? Thanks!
I think they should work now. Cc: @skye
Ah yeah, this should work on Cloud TPU as of jax 0.4.8. I'm gonna close this issue, but please comment or reopen if you find things still aren't working!
oops, yeah! i was running an older version of jax still. Thank you!
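For anyone who lands here with the same crash, a quick sanity check is to compare the installed jax version against the 0.4.8 release mentioned above. The snippet below is only a rough sketch: `parse_version` and `tpu_callbacks_expected` are hypothetical helper names made up for illustration, not part of jax, and the check looks at the jax version string alone (the jaxlib and libtpu versions may matter too).

```python
import itertools

# Threshold taken from the maintainer comment in this thread:
# callbacks are reported to work on Cloud TPU as of jax 0.4.8.
MIN_JAX_FOR_TPU_CALLBACKS = (0, 4, 8)


def parse_version(version):
    """Turn a string such as '0.4.8' or '0.4.8.dev20230401' into (0, 4, 8)."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits of each component ('8rc1' -> '8').
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    # Pad short versions such as '1.0' to three components.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def tpu_callbacks_expected(version):
    """Return True if the given jax version is at least 0.4.8."""
    return parse_version(version) >= MIN_JAX_FOR_TPU_CALLBACKS
```

On a TPU VM this would be called as `tpu_callbacks_expected(jax.__version__)` after `import jax`. With the jax v0.3.17 from the original report it returns `False`, which matches the crash described above.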