`jax.pure_callback` crashes on TPU VM

Opened by nalzok.

Description

This is a follow-up to discussion https://github.com/google/jax/discussions/12245. The solution suggested there by a project collaborator crashes the Python interpreter on my TPU VM, even though it is expected to work on every platform.

Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> import scipy.linalg
>>>
>>> def schur(x):
...   return jax.pure_callback(scipy.linalg.schur, (x, x), x)
...
>>> @jax.jit
... def f(x):
...   return schur(x)
...
>>> print(f(jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)))
F0907 14:29:33.631891  939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/api.py", line 528, in cache_miss
    out_flat = xla.xla_call(
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1963, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1979, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 689, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 841, in _execute_compiled
    out_bufs = token_handler(out_bufs, runtime_token)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 808, in _remove_tokens
    output_token_buf, *token_bufs = token_bufs
jax._src.traceback_util.UnfilteredStackTrace: ValueError: not enough values to unpack (expected at least 1, got 0)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: not enough values to unpack (expected at least 1, got 0)
>>> *** Check failure stack trace: ***
    @     0x7f07a33b98e4  (unknown)
    @     0x7f07a33b93ca  (unknown)
    @     0x7f07a33b9c49  (unknown)
    @     0x7f079ff89e79  (unknown)
    @     0x7f079ff70f66  (unknown)
    @     0x7f079dfa17c3  (unknown)
    @     0x7f07a3297669  (unknown)
    @     0x7f07a32930fe  (unknown)
    @     0x7f0e8753b609  start_thread
https://symbolize.stripped_domain/r/?trace=7f07a33b98e4,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
https://symbolize.stripped_domain/r/?trace=7f0e8759900b,7f0e8759908f,7f07a33b9943,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
*** SIGABRT received by PID 937881 (TID 939012) on cpu 81 from PID 937881; ***
E0907 14:29:33.636138  939012 coredump_hook.cc:370] RAW: Remote crash data gathering hook invoked.
E0907 14:29:33.636153  939012 coredump_hook.cc:416] RAW: Skipping coredump since rlimit was 0 at process start.
E0907 14:29:33.636161  939012 client.cc:242] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0907 14:29:33.636166  939012 coredump_hook.cc:477] RAW: Sending fingerprint to remote end.
E0907 14:29:33.636174  939012 coredump_socket.cc:118] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0907 14:29:33.636184  939012 coredump_hook.cc:481] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0907 14:29:33.636190  939012 coredump_hook.cc:555] RAW: Discarding core.
F0907 14:29:33.631891  939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
E0907 14:29:33.898010  939012 process_state.cc:774] RAW: Raising signal 6 with default behavior
fish: “pipenv run python3” terminated by signal SIGABRT (Abort)

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

TPU v3-8 with libtpu v1.3.0

Additional System Info

Python 3.8.10, TPU VM on GCP running Ubuntu 20.04 (Linux t1v-n-e307e167-w-0 5.13.0-1023-gcp #28~20.04.1-Ubuntu SMP Wed Mar 30 03:51:07 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux)

nalzok, Sep 07 '22
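For comparison, the same Schur callback can be written with explicit `jax.ShapeDtypeStruct` result specs instead of reusing the traced input as the shape template, and with the host function casting SciPy's outputs to the declared dtype. This is a sketch that assumes a JAX release whose runtime supports host callbacks on the target backend; `host_schur` is a hypothetical helper name. The cast is there because `jax.pure_callback` expects the returned arrays to match the declared shapes and dtypes.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def host_schur(a):
    # Hypothetical helper: runs on the host, outside the compiled computation.
    a = np.asarray(a)
    t, z = scipy.linalg.schur(a)
    # Match the dtypes declared in the result specs below.
    return t.astype(a.dtype), z.astype(a.dtype)


def schur(x):
    # scipy.linalg.schur returns (T, Z); both have the same shape and dtype as x.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_schur, (spec, spec), x)


@jax.jit
def f(x):
    return schur(x)


a = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)
t, z = f(a)
print(t)
print(z)
```

`T` is the real Schur form and `Z` is orthogonal, so `z @ t @ z.T` should reproduce the input up to float32 rounding.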

It seems callbacks don't work right now on Cloud TPU VMs, because they still use the older stream_executor runtime. They will soon switch to a newer runtime that does support callbacks. I'll monitor this and update the issue when callbacks work.

sharadmv, Sep 08 '22
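On a runtime without host-callback support, one possible workaround, assuming the decomposition does not need to live inside a jitted function, is to run SciPy eagerly on the host and hand the results back to JAX. This is a sketch under that assumption rather than a fix for the runtime; `schur_on_host` and `reconstruct` are hypothetical names, and the device-side work that consumes the factors can still be jitted on its own.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def schur_on_host(x):
    # Hypothetical helper: nothing here is traced, so the compiled program
    # never contains a host callback.
    a = np.asarray(jax.device_get(x))
    t, z = scipy.linalg.schur(a)
    # Move the factors back to the default device as JAX arrays.
    return jnp.asarray(t), jnp.asarray(z)


@jax.jit
def reconstruct(t, z):
    # Jitted device-side work that consumes the host-computed factors.
    return z @ t @ z.T


a = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)
t, z = schur_on_host(a)
print(reconstruct(t, z))
```

The trade-off is a device-to-host round trip at the call site and no tracing through the decomposition, so it cannot be used inside a larger `jax.jit`-compiled function.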

@sharadmv, any movement on this by any chance? Is there maybe an experimental VM image we can use? Thanks!

dlwh, Apr 17 '23

I think they should work now. Cc: @skye

sharadmv, Apr 17 '23

Ah yeah, this should work on Cloud TPU as of jax 0.4.8. I'm gonna close this issue, but please comment or reopen if you find things still aren't working!

skye, Apr 17 '23
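Since whether callbacks work on Cloud TPU depends on the installed JAX version, a small guard can make the requirement explicit. This sketch assumes the 0.4.8 threshold stated in the comment above and that version strings start with `major.minor.patch`; `parse_version` and `tpu_callbacks_expected_to_work` are hypothetical helpers, not JAX APIs.

```python
import re

# Threshold taken from the maintainer's comment in this thread.
MIN_JAX_FOR_TPU_CALLBACKS = (0, 4, 8)


def parse_version(version):
    # Keep only the leading numeric components, e.g. "0.4.31.dev20240701" -> (0, 4, 31).
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def tpu_callbacks_expected_to_work(version):
    # Tuples compare element-wise, so (0, 3, 17) < (0, 4, 8) < (0, 4, 10).
    return parse_version(version) >= MIN_JAX_FOR_TPU_CALLBACKS
```

For example, `tpu_callbacks_expected_to_work(jax.__version__)` returns False for the reporter's jax 0.3.17.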

Oops, yeah! I was still running an older version of JAX. Thank you!

dlwh, Apr 17 '23