
Compilation hangs indefinitely on GPU

Open dionhaefner opened this issue 4 years ago • 9 comments

I am encountering an issue where compilation on GPU hangs forever in a semi-deterministic way (happens every time, but at slightly different places). All functions have been compiled successfully before (but with different shapes).

This happens in the middle of a huge model codebase, and I unfortunately haven't been able to come up with a reproducer. After 2 minutes I get the "slow compile" warning, and then all I can do is send SIGKILL.

I have dumped the HLO but it looks inconspicuous to me:

https://gist.github.com/dionhaefner/e5680e131975b6bf566c1e1cbc554476

The only lead I have is that right before it hangs, I do something like this:

# <do computations on GPU with JAX>

import numpy as onp
import scipy.sparse.linalg
import jax.numpy as jnp

# pull the JAX arrays over to the host for SciPy
rhs = onp.asarray(rhs)
x0 = onp.asarray(x0)

linear_solution, info = scipy.sparse.linalg.bicgstab(
    _matrix,
    rhs,
    x0=x0,
    atol=0,
    tol=settings.congr_epsilon,
    maxiter=settings.congr_max_iterations,
    **self._extra_args,
)

return jnp.asarray(linear_solution)

# a couple of lines later everything hangs

If I comment out the BiCG solver everything works.
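
For the record, one thing I still want to try (untested sketch) is forcing an explicit device-to-host sync before handing the arrays to SciPy, in case JAX's async dispatch plays a role here:

import jax

# jax.device_get blocks until the computations producing rhs and x0 have
# finished, and returns plain NumPy arrays on the host
rhs = jax.device_get(rhs)
x0 = jax.device_get(x0)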

This happens with JAX built from source and with the current wheels. Downgrading jaxlib did not help either, except that jaxlib 0.1.64 works, albeit poorly (a factor of 10 slower for some reason).

If you have any pointers on how to debug this, I would be grateful.

dionhaefner · May 24 '21 17:05

Your HLO dump only includes a single module. Are you sure that's the module that is hanging? (It doesn't appear to hang for me, and it looks unremarkable.)

hawkinsp · May 24 '21 19:05

Yes, it's this one (according to the slow compilation warning).

The function is unremarkable, and sometimes compilation succeeds (but then the compiler hangs on the next module). So I doubt this has anything to do with it.

To me this sounds like a race condition or something running out of resources.

Anyhow, I was hoping there would be some more debug flags I could use to understand what's going on internally. If that's not the case I will have to get back to you when I can narrow it down a bit.
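
For reference, the only knobs I know of so far are the HLO dump and the compile log; as far as I understand they are set roughly like this (sketch, configured before any JAX work happens):

import os

# dump every compiled HLO module to disk so the hanging one can be inspected
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax

# log each compilation, so the offending jit-compiled function is at least visible
jax.config.update("jax_log_compiles", True)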

dionhaefner · May 24 '21 20:05

Is this a multi-GPU machine? @skye was trying to track down a similar-looking hang that shows up on multi-GPU machines.

hawkinsp · May 24 '21 21:05

Is this a multi-GPU machine?

Yes, it is!

After you mentioned that I hoped that something like CUDA_VISIBLE_DEVICES=0 would fix it, but that doesn't seem to be the case...

dionhaefner · May 25 '21 07:05

Ok, I think we'll need a way to reproduce this. If you share all of the HLO dumps, it's possible we can reproduce from that, although a Python repro would be best.

hawkinsp · May 25 '21 12:05

I don't have a self-contained example, sorry.

This is how I can reproduce it reliably:

$ git clone https://github.com/team-ocean/veros.git
$ cd veros
$ git checkout 6d7833a1801c028048f4d7cf6b80e86bec7f3224
$ pip install -e .
$ pip install jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ export VEROS_LINEAR_SOLVER=scipy
$ python benchmarks/acc_benchmark.py --size 304 304 108 --timesteps 10 -b jax --device gpu

Importing core modules
 Using computational backend jax on gpu
  Kernels are compiled during first iteration, be patient
 Runtime settings are now locked

Running model setup

# lots of output, until...

Initializing linear solver
Computing ILU preconditioner...
 Solving for boundary contribution by island 0
2021-05-26 13:05:29.319948: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module primitive_computation_scatter__8.10
********************************

Platform info:

$ nvidia-smi
Wed May 26 13:04:13 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:03:00.0 Off |                    0 |
| N/A   25C    P0    31W / 250W |  14690MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:04:00.0 Off |                    0 |
| N/A   25C    P0    32W / 250W |    257MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

It would be valuable to know whether someone else can reproduce this on different hardware (we only have this one GPU machine). But I also understand that you can't debug my code for me, so I'm putting this out there mostly as a reminder to self.

dionhaefner · May 26 '21 11:05

This sounds related: https://github.com/google/jax/issues/6471

In both cases, we get hanging compilations on GPU after calling scipy.sparse.linalg.bicgstab. Would be a hell of a coincidence...

dionhaefner · May 26 '21 16:05

FWIW, this does not occur when I do

$ export OMP_NUM_THREADS=1

Could this be SciPy's internal OpenMP parallelization clashing with JAX's thread parallelism?
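
If so, a possible workaround (just a sketch, assuming the clash really comes from the OpenMP/BLAS pools used on the SciPy side; _matrix, rhs, x0 are the same objects as in the snippet above) would be to limit threads only around the host-side solve via threadpoolctl, rather than globally:

from threadpoolctl import threadpool_limits
import scipy.sparse.linalg

# pin the OpenMP/BLAS pools to a single thread only for the host-side solve,
# leaving JAX's own thread pools untouched
with threadpool_limits(limits=1):
    linear_solution, info = scipy.sparse.linalg.bicgstab(_matrix, rhs, x0=x0, atol=0)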

dionhaefner · Jun 15 '21 08:06

I encountered the same problem when I installed brax with pip install brax and ran training with learn --learner=ppo --env=humanoid on a CentOS server with 4 GPUs. The output is:

I1029 07:57:57.650648 140016890124032 logging_writer.py:57] Hyperparameters: {'log_frequency': 10, 'num_envs': 4, 'total_env_steps': 50000000}
I1029 07:57:57.664291 140026811004672 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I1029 07:57:57.674494 140026811004672 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I1029 07:57:57.674688 140026811004672 ppo.py:218] Device count: 4, process count: 1 (id 0), local device count: 4, devices to be used count: 4
2021-10-29 07:59:58.028158: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module primitive_computation_convert_element_type.3
********************************

The HLO dump is attached: xla_slow.tar.gz. I tried the same training parameters on another Ubuntu server with only 1 GPU and there was no hang. But if I run training on the CentOS server with CUDA_VISIBLE_DEVICES=0, it still gets stuck. When I run training with export OMP_NUM_THREADS=1, it is no longer stuck.

chaihahaha · Oct 29 '21 04:10

@dionhaefner Was this resolved? Do you still need help?

sudhakarsingh27 · Aug 15 '22 18:08

Seems fixed with recent JAX, thanks!

dionhaefner · Aug 15 '22 19:08