ptxas error : Entry function 'fusion_##' uses too much shared data
I'm running into an error when I attempt to JIT a rather involved gradient when using JAX with x64 enabled.
The error I get is:
2020-05-17 13:01:34.512390: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:414] Error during compilation of ptx to sass: Internal: ptxas exited with non-zero error code 65280, output: ptxas error : Entry function 'fusion_42' uses too much shared data (0xc600 bytes, 0xc000 max)
2020-05-17 13:01:37.251877: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:576] failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
2020-05-17 13:01:37.251910: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:581] error log buffer (95 bytes): ptxas error : Entry function 'fusion_42' uses too much shared data (0xc600 bytes, 0xc000 max
2020-05-17 13:01:37.252001: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1484] Execution of replica 0 failed: Internal: Failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
0%| | 0/50000 [00:10<?, ?it/s]
Traceback (most recent call last):
File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/__main__.py", line 45, in <module>
cli.main()
File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/steven/.vscode/extensions/ms-python.python-2020.5.78807/pythonFiles/lib/python/debugpy/wheels/debugpy/../debugpy/server/cli.py", line 267, in run_file
runpy.run_path(options.target, run_name=compat.force_str("__main__"))
File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/home/steven/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/steven/src/file1.py", line 191, in <module>
main()
File "/home/steven/src/file1.py", line 165, in main
rng_train, data, train_settings, testing
File "/home/steven/src/file2.py", line 90, in train
opt_state = step(i, rng_step, opt_state, data)
File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/api.py", line 154, in f_jitted
name=flat_fun.__name__)
File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/core.py", line 1021, in _call_bind
outs = primitive.impl(f, *args, **params)
File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/interpreters/xla.py", line 523, in _xla_call_impl
return compiled_fun(*args)
File "/home/steven/src/venv-gpu/lib/python3.7/site-packages/jax/interpreters/xla.py", line 632, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: Internal: Failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
On one hand, I've traced the problematic operation back to a classic scaled squared-distance computation, i.e.:
import jax.numpy as np

def _squared_distance(x1, x2, scales=None):
    """
    :param x1: np.array, (N,D)
    :param x2: np.array, (M,D)
    :return: np.array, (N,M)
    """
    # Error appears when scales is not None
    z1, z2 = (x1, x2) if scales is None else (x1 / scales, x2 / scales)
    return (
        np.sum(z1 * z1, axis=1, keepdims=True)
        - 2.0 * z1 @ z2.T
        + np.sum(z2 * z2, axis=1, keepdims=True).T
    )
The error appears once both N and M are 32 or larger (31 is ok).
However, the bug seems to require other bits (that I haven't nailed down yet, regretfully) leading into this function to also be sufficiently complex. The hints I've gathered so far are (a purely structural sketch follows the list):
- x2 is the 3rd derivative of a neural net. The 2nd derivative doesn't cause the error to appear, and higher-order derivatives (tested up to 10) are fine when x64 isn't enabled.
- There is a grad over the whole function that this is part of. This also seems to be necessary.
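To make that structure concrete, here is a purely illustrative skeleton of the setup (the toy net, parameter names, and shapes are all made up; this is not a verified reproduction of the bug):

import jax
import jax.numpy as np

# Toy stand-in for the real network: scalar in, scalar out.
def net(params, x):
    w, b = params
    return np.tanh(w * x + b)

# "x2 is the 3rd derivative of a neural net"
d3_net = jax.grad(jax.grad(jax.grad(net, argnums=1), argnums=1), argnums=1)

def loss(params, x1, xs):
    x2 = jax.vmap(lambda x: d3_net(params, x))(xs)[:, None]  # (M, 1)
    d2 = _squared_distance(x1, x2, scales=np.array([2.0]))   # (N, M)
    return np.sum(d2)

# "There is a grad over the whole function that this is part of"
step = jax.jit(jax.grad(loss))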
Perhaps this is related to #677?
My questions are:
- Do we have any ideas why this happens? It seems like an XLA issue?
- Thinking about possible workarounds, is there some way I can give the XLA compiler a hint not to fuse certain lines? I think that the problematic line should be negligible cost overall and suspect that this would be a less intrusive fix if such a thing exists.
Also, if you have any hunches based on past experiences that might help me create my MWE, I'd love to hear them so I can help get to the bottom of this.
Additional setup info (can add more info as needed):
- GPU is an RTX 2070, CUDA is 10.1
- jax version is 0.1.67, jaxlib is 0.1.47, installed via pip.
Thanks in advance for your attention, and thanks for developing this wonderful software!
Thanks for the beautiful and careful writeup, even for a pretty slippery-sounding bug! I pinged the XLA:GPU folks for any immediate gut reactions. Indeed #677 looks pretty similar.
Will report back with what we figure out!
One of our good friends on XLA:CPU/GPU suggests "it means that it's most likely an XLA bug and needs tweaking of this heuristic." Though I believe that's a gut reaction and subject to change upon further investigation.
One way to prohibit fusion (although it's kind of a big hammer) is to use jax.remat (intended for rematerialization/checkpointing of gradients, but with the relevant side meaning of "semantically the identity, but compilers can't look inside"). If you're still blocked on this, maybe give that a try?
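For example, a minimal sketch of what that could look like here (reusing the _squared_distance from above; untested against this particular failure):

import jax
import jax.numpy as np

# jax.remat is semantically the identity but opaque to the compiler, so
# wrapping the distance computation should keep XLA from fusing it into
# the surrounding code.
_squared_distance_nofuse = jax.remat(_squared_distance)

def kernel(x1, x2, scales):
    d2 = _squared_distance_nofuse(x1, x2, scales)  # scales passed positionally
    return np.exp(-0.5 * d2)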
There's a slow-moving discussion in the XLA chat room, though it's not clear yet if anyone has the bandwidth to pick this up right now. I'll keep poking, but yeah it'd be interesting if jax.remat helped as a workaround.
Thanks for the input so far.
I haven't had the time to try with jax.remat yet, but I found that implementing the function above using a pair of jax.vmap calls, like in the GP example, gets rid of the issue (rough sketch below), though it seems to be a little less efficient. Fortunately, this isn't a computational bottleneck, so it's good enough to be workable for now.
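Roughly, the vmap version looks like this (a sketch of the pattern rather than my exact code; the helper names are made up):

import jax
import jax.numpy as np

def _sq_dist_one(z1, z2):
    # Squared distance between two single points of shape (D,).
    diff = z1 - z2
    return np.sum(diff * diff)

def _squared_distance_vmap(x1, x2, scales=None):
    # Same contract as above: x1 (N, D), x2 (M, D) -> (N, M), but built
    # from per-pair distances via nested vmaps instead of a matmul.
    z1, z2 = (x1, x2) if scales is None else (x1 / scales, x2 / scales)
    return jax.vmap(lambda a: jax.vmap(lambda b: _sq_dist_one(a, b))(z2))(z1)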
I'll report back when I've had a chance to try w/ remat.
Interesting! Indeed it can even be asymptotically less efficient without the matmul there, but good to know there's a workaround. Will keep you posted with updates...
Do you still see this bug? I believe a bug like this was fixed. If you still hit it, can you run with the environment flag below, dump the stdout/stderr to a file, and send the file? This will allow us to reproduce your issue.
XLA_FLAGS="--xla_dump_hlo_as_text"
As there's been no news and I think this has been fixed, I'll close it. If you still see it, just reopen this bug.