
ptxas error : Entry function 'fusion_##' uses too much shared data

Open sdatkinson opened this issue 5 years ago • 6 comments

I'm running into an error when I attempt to JIT a rather involved gradient when using JAX with x64 enabled.

The error I get is:

2020-05-17 13:01:34.512390: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:414] Error during compilation of ptx to sass: Internal: ptxas exited with non-zero error code 65280, output: ptxas error   : Entry function 'fusion_42' uses too much shared data (0xc600 bytes, 0xc000 max)

2020-05-17 13:01:37.251877: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:576] failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
2020-05-17 13:01:37.251910: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:581] error log buffer (95 bytes): ptxas error   : Entry function 'fusion_42' uses too much shared data (0xc600 bytes, 0xc000 max
2020-05-17 13:01:37.252001: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1484] Execution of replica 0 failed: Internal: Failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
  0%|          | 0/50000 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/../debugpy/server/cli.py", line 267, in run_file
    runpy.run_path(options.target, run_name=compat.force_str("__main__"))
  File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/steven/src/file1.py", line 191, in <module>
    main()
  File "/home/steven/src/file1.py", line 165, in main
    rng_train, data, train_settings, testing
  File "/home/steven/src/file2.py", line 90, in train
    opt_state = step(i, rng_step, opt_state, data)
  File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/api.py", line 154, in f_jitted
    name=flat_fun.__name__)
  File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/core.py", line 1021, in _call_bind
    outs = primitive.impl(f, *args, **params)
  File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/interpreters/xla.py", line 523, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/interpreters/xla.py", line 632, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Internal: Failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
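For reference, the two hex values in the ptxas message decode as shown below. The fused kernel asked for 1.5 KiB more shared memory than the 48 KiB (0xc000) allowed, which matches CUDA's usual limit on statically allocated shared memory per block. The halving at the end only illustrates why x64 might matter, under the assumption that the fusion's shared buffers scale with element size. It is not a diagnosis.

```python
# Decode the numbers in the ptxas message above.
used = 0xC600   # shared memory the fused kernel requested, in bytes
limit = 0xC000  # maximum ptxas allowed, in bytes

print(used, limit, used - limit)      # 50688 49152 1536
print(limit // 1024, "KiB limit")     # 48 KiB limit

# Assumption for illustration only: if every shared buffer in the fusion held
# float64 values, the same buffers in float32 would need half the space.
print(used // 2, used // 2 <= limit)  # 25344 True
```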

I've traced the problematic operation back to a classic scaled squared-distance computation, i.e.:

import jax.numpy as np  # assumption: np refers to jax.numpy in the original code

def _squared_distance(x1, x2, scales=None):
    """
    :param x1: np.array, (N,D)
    :param x2: np.array, (M,D)
    :param scales: optional length scales, broadcastable to (D,)
    :return: np.array, (N,M)
    """
    # Error appears when scales is not None
    z1, z2 = (x1, x2) if scales is None else (x1 / scales, x2 / scales)
    return (
        np.sum(z1 * z1, axis=1, keepdims=True)
        - 2.0 * z1 @ z2.T
        + np.sum(z2 * z2, axis=1, keepdims=True).T
    )
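The function relies on the usual expansion of a squared Euclidean distance. Here $s$ denotes `scales`, $\oslash$ is elementwise division, and $x_{1,i}$ is row $i$ of `x1` (notation introduced for this note). The cross term is the `z1 @ z2.T` matmul, and the two squared norms are the row-wise sums broadcast to $(N, M)$:

$$
\lVert z_{1,i} - z_{2,j} \rVert^2 = \lVert z_{1,i} \rVert^2 - 2\, z_{1,i}^{\top} z_{2,j} + \lVert z_{2,j} \rVert^2,
\qquad z_{1,i} = x_{1,i} \oslash s,\quad z_{2,j} = x_{2,j} \oslash s
$$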

The error appears once both N and M are 32 or larger (31 is OK). However, the bug also seems to require that the computation leading into this function be sufficiently complex, in ways I regretfully haven't nailed down yet. The hints I've gathered so far are:

  • x2 is the third derivative of a neural net. Using the second derivative doesn't trigger the error, and higher-order derivatives (tested up to the 10th) work fine when x64 isn't enabled.
  • There is a grad over the whole function that this is part of; this also seems to be necessary.
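For anyone trying to build the MWE, here is a minimal sketch that wires together the ingredients listed above: x64 enabled, `x2` computed from a third derivative of a network, and an outer `grad` under `jit`. The network `net`, its parameter shapes, and the loss are hypothetical stand-ins rather than the author's model, and this sketch is not known to reproduce the ptxas error. It runs on CPU as a scaffold to grow from.

```python
import jax
import jax.numpy as jnp

# The issue only shows up with x64, so enable it before creating any arrays.
jax.config.update("jax_enable_x64", True)

def net(params, t):
    # Hypothetical one-hidden-layer network mapping a scalar t to a scalar.
    w1, b1, w2 = params
    h = jnp.tanh(w1 * t + b1)  # hidden activations, shape (H,)
    return jnp.dot(w2, h)

# Third derivative of the network output with respect to its input t.
d3net = jax.grad(jax.grad(jax.grad(net, argnums=1), argnums=1), argnums=1)

def squared_distance(x1, x2, scales):
    z1, z2 = x1 / scales, x2 / scales
    return (jnp.sum(z1 * z1, axis=1, keepdims=True)
            - 2.0 * z1 @ z2.T
            + jnp.sum(z2 * z2, axis=1, keepdims=True).T)

def loss(params, x1, ts, scales):
    # x2 holds the third derivative at each input point, shape (M, 1).
    x2 = jax.vmap(lambda t: d3net(params, t))(ts)[:, None]
    return jnp.sum(squared_distance(x1, x2, scales))

# Outer grad over the whole function (w.r.t. params), compiled with jit.
loss_grad = jax.jit(jax.grad(loss))

# Usage with N = M = 32, the size at which the error first appeared.
H, N, M = 16, 32, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (H,)),
          jax.random.normal(k2, (H,)),
          jax.random.normal(k3, (H,)))
x1 = jnp.linspace(-1.0, 1.0, N)[:, None]
ts = jnp.linspace(-1.0, 1.0, M)
scales = jnp.array([0.5])
grads = loss_grad(params, x1, ts, scales)
```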

Perhaps this is related to #677?

My questions are:

  • Does anyone have an idea why this happens? It looks like an XLA issue.
  • Thinking about possible workarounds, is there some way to hint to the XLA compiler not to fuse certain lines? The problematic line should have negligible cost overall, so I suspect such a hint, if it exists, would be a less intrusive fix.

Also, if you have any hunches based on past experiences that might help me create my MWE, I'd love to hear them so I can help get to the bottom of this.

Additional setup info (can add more info as needed):

  • GPU is an RTX 2070, CUDA is 10.1
  • jax version is 0.1.67, jaxlib is 0.1.47, installed via pip.

Thanks in advance for your attention, and thanks for developing this wonderful software!

sdatkinson, May 17 '20

Thanks for the beautiful and careful writeup, even for a pretty slippery-sounding bug! I pinged the XLA:GPU folks for any immediate gut reactions. Indeed #677 looks pretty similar.

Will report back with what we figure out!

mattjj, May 17 '20

One of our good friends on the XLA:CPU/GPU team suggests that "it means that it's most likely an XLA bug and needs tweaking of this heuristic," though I believe that's a gut reaction and subject to change upon further investigation.

mattjj, May 17 '20

One way to prohibit fusion (although it's kind of a big hammer) is to use jax.remat (intended for rematerialization/checkpointing of gradients, but with the relevant side meaning of "semantically the identity, but compilers can't look inside"). If you're still blocked on this, maybe give that a try?
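Applied to the function above, that suggestion might look like the sketch below. It uses `jax.checkpoint`, which the JAX docs describe as aliased to `jax.remat`. Whether this actually keeps the distance computation out of the offending fusion depends on the JAX/XLA version and is not guaranteed; the input shapes and values here are arbitrary.

```python
import jax
import jax.numpy as jnp

def squared_distance(x1, x2, scales):
    z1, z2 = x1 / scales, x2 / scales
    return (jnp.sum(z1 * z1, axis=1, keepdims=True)
            - 2.0 * z1 @ z2.T
            + jnp.sum(z2 * z2, axis=1, keepdims=True).T)

# jax.checkpoint, also available under the alias jax.remat used in the
# comment above. It changes what is saved vs. recomputed for the backward
# pass; it does not promise that XLA will avoid any particular fusion.
squared_distance_remat = jax.checkpoint(squared_distance)

@jax.jit
def loss(x1, x2, scales):
    return jnp.sum(squared_distance_remat(x1, x2, scales))

# Usage: gradient with respect to the length scales at N = M = 32.
x1 = jnp.ones((32, 3))
x2 = jnp.arange(96.0).reshape(32, 3) / 96.0
scales = jnp.array([1.0, 2.0, 0.5])
g = jax.grad(loss, argnums=2)(x1, x2, scales)
```

The result should match the gradient of the unwrapped function; only the compiled program's structure is meant to change.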

jekbradbury, May 19 '20

There's a slow-moving discussion in the XLA chat room, though it's not clear yet if anyone has the bandwidth to pick this up right now. I'll keep poking, but yeah it'd be interesting if jax.remat helped as a workaround.

mattjj, May 19 '20

Thanks for the input so far.

I haven't had the time to try with jax.remat yet but found that implementing the function above using a pair of jax.vmaps like the GP example gets rid of the issue, though it seems to be a little less efficient. Fortunately, this isn't a computational bottleneck so it's good enough to be workable for now.

I'll report back when I've had a chance to try w/ remat.
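The vmap-based workaround described here can be sketched roughly as follows, loosely modeled on the pairwise-map pattern in JAX's Gaussian process example (the helper names are hypothetical). Instead of expanding the distance into a matmul, it evaluates each pair's difference directly, which is also why it may run somewhat slower.

```python
import jax
import jax.numpy as jnp

def sq_dist_pair(a, b, scales):
    # Scaled squared distance between two single points of shape (D,).
    d = (a - b) / scales
    return jnp.sum(d * d)

def squared_distance_vmap(x1, x2, scales):
    # Inner vmap runs over the rows of x2, outer vmap over the rows of x1,
    # giving an (N, M) matrix without the z1 @ z2.T matmul.
    row = lambda a: jax.vmap(lambda b: sq_dist_pair(a, b, scales))(x2)
    return jax.vmap(row)(x1)

def squared_distance_matmul(x1, x2, scales):
    # The original expansion-based version, for comparison.
    z1, z2 = x1 / scales, x2 / scales
    return (jnp.sum(z1 * z1, axis=1, keepdims=True)
            - 2.0 * z1 @ z2.T
            + jnp.sum(z2 * z2, axis=1, keepdims=True).T)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jax.random.normal(k1, (5, 3))
x2 = jax.random.normal(k2, (4, 3))
scales = jnp.array([0.5, 1.0, 2.0])
d_vmap = squared_distance_vmap(x1, x2, scales)
d_matmul = squared_distance_matmul(x1, x2, scales)
```

Both versions agree up to floating-point rounding; the matmul form can lose a little precision to cancellation when points are close together.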

sdatkinson, May 20 '20

Interesting! Indeed it can even be asymptotically less efficient without the matmul there, but good to know there's a workaround. Will keep you posted with updates...

mattjj, May 20 '20

Do you still have this bug? I know a similar bug was fixed. If you still have it, can you run with this environment flag set, dump the stdout/stderr to a file, and send us the file? This will allow us to reproduce your issue.

XLA_FLAGS="--xla_dump_hlo_as_text"
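The flag can be exported in the shell that launches the script, or set from Python as in this sketch, as long as that happens before JAX initializes its backend. The `--xla_dump_to` flag and the temporary directory are additions for this sketch, not part of the request above; drop `--xla_dump_to` to get the default output location.

```python
import os
import tempfile

# XLA parses XLA_FLAGS when the backend starts, so set the variable before
# importing jax (exporting it in the launching shell works the same way).
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
# --xla_dump_to (an extra, standard XLA flag added for this sketch) sends the
# HLO text dumps to dump_dir instead of the default output stream.
os.environ["XLA_FLAGS"] = f"--xla_dump_hlo_as_text --xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp

# Any jitted computation now has its HLO modules dumped as text.
out = jax.jit(lambda x: 2.0 * x)(jnp.ones(3))
```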

nouiz, Sep 14 '22

As there has been no news and I think it is fixed, I'll close it. If you still see it, just reopen this bug.

nouiz, Sep 21 '22