Compilation hangs indefinitely on GPU
I am encountering an issue where compilation on GPU hangs forever in a semi-deterministic way (happens every time, but at slightly different places). All functions have been compiled successfully before (but with different shapes).
This happens in the middle of a huge model codebase, and I unfortunately haven't been able to come up with a reproducer. After 2 minutes I get the "slow compile" warning, then all I can do is send SIGKILL.
I have dumped the HLO but it looks inconspicuous to me:
https://gist.github.com/dionhaefner/e5680e131975b6bf566c1e1cbc554476
The only lead I have is that right before it hangs, I do something like this:
# <do computations on GPU with JAX>
import numpy as onp
import scipy.sparse.linalg
import jax.numpy as jnp

# pull the right-hand side and initial guess back to the host as NumPy arrays
rhs = onp.asarray(rhs)
x0 = onp.asarray(x0)

# solve the linear system on the CPU with SciPy's BiCGSTAB
linear_solution, info = scipy.sparse.linalg.bicgstab(
    self._matrix,
    rhs,
    x0=x0,
    atol=0,
    tol=settings.congr_epsilon,
    maxiter=settings.congr_max_iterations,
    **self._extra_args,
)

# transfer the solution back to the GPU as a JAX array
return jnp.asarray(linear_solution)
# a couple of lines later everything hangs
If I comment out the BiCG solver everything works.
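For what it's worth, the onp.asarray calls already force a device-to-host copy, which blocks until the GPU work producing rhs and x0 has finished. Here is a sketch (my illustration only; pull_to_host is a made-up helper) that makes that synchronization explicit, which might help tell whether the hang happens in the transfer itself or only afterwards:

# Illustrative sketch (not the actual model code): make the GPU -> host
# synchronization explicit, to see whether the hang is in the transfer itself.
import jax

def pull_to_host(x):
    """Block until the GPU computation producing `x` finishes, then copy it to host memory."""
    x = x.block_until_ready()   # explicit sync point on the device array
    return jax.device_get(x)    # device-to-host copy, returns a NumPy array

i.e. calling rhs = pull_to_host(rhs) in place of the onp.asarray(rhs) above.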
This happens with JAX built from source and with the current wheels, and downgrading jaxlib did not help either. The exception is jaxlib 0.1.64, which works, albeit poorly (a factor of 10 slower for some reason).
If you have any pointer on how to debug this I would be grateful.
Your HLO dump only includes a single module. Are you sure that's the module that is hanging? (It doesn't appear to hang for me, and it looks unremarkable.)
Yes, it's this one (according to the slow compilation warning).
The function is unremarkable, and sometimes compilation succeeds (but then the compiler hangs on the next module). So I doubt this has anything to do with it.
To me this sounds like a race condition or something running out of resources.
Anyhow, I was hoping there would be some more debug flags I could use to understand what's going on internally. If that's not the case I will have to get back to you when I can narrow it down a bit.
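For reference, the two knobs that do exist are the HLO dump flag from the slow-compile warning and JAX's own compile logging. A rough sketch of how to enable both (nothing specific to this hang):

# Minimal debugging setup (sketch): dump every HLO module XLA compiles and log
# each compilation from the JAX side, so the module that hangs can be identified.
import os

# XLA reads this when the backend is initialized, so set it before importing jax.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax

# Print a log line (including the jitted function's name) for every compilation.
jax.config.update("jax_log_compiles", True)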
Is this a multi-GPU machine? @skye was trying to track down a similar-looking hang that exhibits on multi GPU machines.
Is this a multi-GPU machine?
Yes, it is!
After you mentioned that I hoped that something like CUDA_VISIBLE_DEVICES=0 would fix it, but that doesn't seem to be the case...
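For completeness, this is the quick sanity check (a sketch, nothing more) I'd use to confirm that JAX really does see only one device after restricting CUDA_VISIBLE_DEVICES:

# Sanity-check sketch: restrict the process to GPU 0 and confirm JAX agrees.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"   # must be set before importing jax

import jax
print(jax.devices())             # should list a single GPU device
print(jax.local_device_count())  # should print 1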
Ok, I think we'll need a way to reproduce this. If you share all of the HLO dumps, it's possible we can reproduce from that, although a Python repro would be best.
I don't have a self-contained example, sorry.
This is how I can reproduce it reliably:
$ git clone https://github.com/team-ocean/veros.git
$ cd veros
$ git checkout 6d7833a1801c028048f4d7cf6b80e86bec7f3224
$ pip install -e .
$ pip install jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ export VEROS_LINEAR_SOLVER=scipy
$ python benchmarks/acc_benchmark.py --size 304 304 108 --timesteps 10 -b jax --device gpu
Importing core modules
Using computational backend jax on gpu
Kernels are compiled during first iteration, be patient
Runtime settings are now locked
Running model setup
# lots of output, until...
Initializing linear solver
Computing ILU preconditioner...
Solving for boundary contribution by island 0
2021-05-26 13:05:29.319948: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module primitive_computation_scatter__8.10
********************************
Platform info:
$ nvidia-smi
Wed May 26 13:04:13 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:03:00.0 Off |                    0 |
| N/A   25C    P0    31W / 250W |  14690MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:04:00.0 Off |                    0 |
| N/A   25C    P0    32W / 250W |    257MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
It would be valuable to know whether someone else can reproduce this on different hardware (we only have this one GPU machine). But I also understand that you can't debug my code for me, so I'm putting this out there mostly as a reminder to self.
This sounds related: https://github.com/google/jax/issues/6471
In both cases, we get hanging compilations on GPU after calling scipy.sparse.linalg.bicgstab. Would be a hell of a coincidence...
FWIW, this does not occur when I do
$ export OMP_NUM_THREADS=1
Could this be SciPy's internal OpenMP parallelization clashing with JAX's thread parallelism?
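If that's what's happening, capping the native thread pools only around the host-side solve should have the same effect as a global OMP_NUM_THREADS=1. A sketch of what I mean, using the threadpoolctl package (untested; matrix, rhs, x0 and the tolerances are placeholders for the real inputs):

# Untested sketch: cap OpenMP/BLAS threads only for the SciPy solve, instead of
# setting OMP_NUM_THREADS=1 for the whole process.
import scipy.sparse.linalg
from threadpoolctl import threadpool_limits

def solve_on_host(matrix, rhs, x0, tol, maxiter):
    with threadpool_limits(limits=1):  # single-threaded native libraries inside this block
        solution, info = scipy.sparse.linalg.bicgstab(
            matrix, rhs, x0=x0, atol=0, tol=tol, maxiter=maxiter
        )
    return solution, info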
I encountered the same problem when I installed brax with pip install brax and ran training with learn --learner=ppo --env=humanoid on a CentOS server with 4 GPUs.
The output is:
I1029 07:57:57.650648 140016890124032 logging_writer.py:57] Hyperparameters: {'log_frequency': 10, 'num_envs': 4, 'total_env_steps': 50000000}
I1029 07:57:57.664291 140026811004672 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I1029 07:57:57.674494 140026811004672 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I1029 07:57:57.674688 140026811004672 ppo.py:218] Device count: 4, process count: 1 (id 0), local device count: 4, devices to be used count: 4
2021-10-29 07:59:58.028158: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module primitive_computation_convert_element_type.3
********************************
The HLO dump is attached:
xla_slow.tar.gz
I tried the same training parameters on another Ubuntu server with only 1 GPU and there was no hanging. But if I run training on the CentOS server with CUDA_VISIBLE_DEVICES=0, it still gets stuck.
When I then ran training with export OMP_NUM_THREADS=1, it was no longer stuck.
@dionhaefner Was this resolved? Do you still need help?
Seems fixed with recent JAX, thanks!