Make Pallas/GPU easier to install
Currently it's very difficult to install Pallas and jax_triton, since you have to get compatible versions of everything, and it's very finicky to work out which they are. We should make this easier!
I thought it would be useful to post here that the following steps got me a working installation on top of the ghcr.io/nvidia/jax:nightly-2023-12-08 container (from https://github.com/NVIDIA/JAX-Toolbox):
# pip install --no-deps 'jax-triton@git+https://github.com/jax-ml/jax-triton.git@test_588045313' # e4fd5cb21f40c3991a204479c3a1a0e3f0194e91
# cd /opt
# git clone -b llvm-head https://github.com/openai/triton.git # ca78acaf1a6cf68e2af8a68762ec852534ff0610
# cd triton/
# pip install -e python # quite slow
# cd /opt/jax
# git checkout test_588045313 # 56c46f36df6cfaa40253e839203f950481ce97cd
# build-jax.sh # seems to rebuild more than I expected, might be improvable
the test_588045313 branches refer to these PRs:
- https://github.com/google/jax/pull/18838
- https://github.com/jax-ml/jax-triton/pull/241
and IIUC we would prefer to build Triton from https://github.com/openxla/triton. Hopefully once those PRs land the recipe could be simplified a little.
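Once the build finishes, a quick sanity check is to run the add-vectors example from the Pallas quickstart (this is just the quickstart kernel, not part of the recipe itself):
# Sanity check adapted from the Pallas quickstart; if the GPU lowering is wired up
# correctly, this runs without the "Cannot lower pallas_call" error shown below.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Read both inputs from their refs, add them, and write to the output ref.
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

print(add_vectors(jnp.arange(8), jnp.arange(8)))  # expected: [ 0  2  4  6  8 10 12 14]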
Big ➕ to the sentiment of this issue: this should be easier!
Yes please!
+1! Started to play with https://github.com/google/maxtext but immediately ran into Pallas/Jax-Triton errors, and after some effort I haven't been able to get a jax build with working Pallas/jax-triton.
+1 as well, I spent quite a while installing different permutations of versions of things but I couldn't find one that worked.
Another big +1 here, would really love to use Pallas but I cannot find a correct set of commands to install it correctly.
Hi everyone, I was able to get a set of mutually compatible version pins. I hope this unblocks folks temporarily while we push on the longer term solution (bundle Triton with jaxlib GPU).
Versions:
triton-nightly==2.1.0.post20231216005823
jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
jaxlib==0.4.24.dev20240103
jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c
Installation commands:
$ pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
$ pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$ pip install -IU --pre jaxlib[cuda12_pip]==0.4.24.dev20240103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
$ pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'
I verified this worked with a simple kernel in colab but haven't yet run more extensive tests. Please let me know if these work for you.
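A quick way to confirm the pins resolved in your environment (just a sketch, not part of the pins themselves):
import jax
import jaxlib
import triton
import jax_triton  # raises ImportError if jax-triton or one of its dependencies is missing

print("jax", jax.__version__)        # should match the jax nightly pin above
print("jaxlib", jaxlib.__version__)  # should match the jaxlib nightly pin above
print("triton", triton.__version__)  # should match the triton nightly pin above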
Thanks @sharadmv! Appreciate the quick workaround.
Here is a copy-pasteable version (note: you need to be using Python 3.9 or 3.10):
pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install -IU --pre "jaxlib[cuda12_pip]==0.4.24.dev20240103" -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'
EDIT: Though for me, this still didn't work locally, perhaps some interference with my local CUDA?
What's your error?
This happens when running the Hello World example from the Quickstart, in a fresh environment set up with the block above, with CUDA 12.3.1-2 installed locally:
2024-01-04 21:03:19.992205: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12030. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:196, in _run_module_as_main()
195 sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
197 "__main__", mod_spec)
File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:86, in _run_code()
79 run_globals.update(__name__ = mod_name,
80 __file__ = fname,
81 __cached__ = cached,
(...)
84 __package__ = pkg_name,
85 __spec__ = mod_spec)
---> 86 exec(code, run_globals)
87 return run_globals
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel_launcher.py:17
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/traitlets/config/application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:701, in start()
700 try:
--> 701 self.io_loop.start()
702 except KeyboardInterrupt:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/tornado/platform/asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()
File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:603, in run_forever()
602 while True:
--> 603 self._run_once()
604 if self._stopping:
File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
1908 else:
-> 1909 handle._run()
1910 handle = None
File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
533 try:
--> 534 await self.process_one()
535 except Exception:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:523, in process_one()
522 return
--> 523 await dispatch(*args)
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
428 if inspect.isawaitable(result):
--> 429 await result
430 except Exception:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:767, in execute_request()
766 if inspect.isawaitable(reply_content):
--> 767 reply_content = await reply_content
769 # Flush output before sending the reply.
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:429, in do_execute()
428 if accepts_params["cell_id"]:
--> 429 res = shell.run_cell(
430 code,
431 store_history=store_history,
432 silent=silent,
433 cell_id=cell_id,
434 )
435 else:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
3050 try:
-> 3051 result = self._run_cell(
3052 raw_cell, store_history, silent, shell_futures, cell_id
3053 )
3054 finally:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
3105 try:
-> 3106 result = runner(coro)
3107 except BaseException as e:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
3308 interactivity = "none" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3312 interactivity=interactivity, compiler=compiler, result=result)
3314 self.last_execution_succeeded = not has_raised
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
3492 asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
3494 return True
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
3552 else:
-> 3553 exec(code_obj, self.user_global_ns, self.user_ns)
3554 finally:
3555 # Reset our crash handler in place
Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))
Cell In[3], line 7, in add_vectors()
5 @jax.jit
6 def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
----> 7 return pl.pallas_call(add_vectors_kernel,
8 out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:456, in wrapped()
455 which_linear = (False,) * len(flat_args)
--> 456 out_flat = pallas_call_p.bind(
457 *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
458 in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
459 for a in flat_args),
460 out_shapes=tuple(flat_out_shapes), debug=debug,
461 interpret=interpret,
462 grid_mapping=grid_mapping,
463 input_output_aliases=tuple(input_output_aliases.items()),
464 **compiler_params)
465 out = tree_util.tree_unflatten(out_tree, out_flat)
JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))
[... skipping hidden 23 frame]
File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:415, in _pallas_call_default_lowering(ctx, interpret, *in_nodes, **params)
411 raise ValueError("Only interpret mode is supported on CPU backend.")
412 # If we are actually using a specific backend (GPU or TPU), we should have
413 # already registered backend-specific lowerings. If we get this far, it means
414 # those backends aren't present.
--> 415 raise ValueError(
416 f"Cannot lower pallas_call on platform: {platform}. "
417 "To use Pallas on GPU, please install Triton and JAX-Triton. "
418 "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")
ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.
Can you try running this Python snippet?
import triton
from triton.compiler import code_generator as code_gen
from triton.compiler import compiler as tc
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb
Sorry, I missed a warning during the install:
jax-triton 0.1.4 requires absl-py>=1.4.0, which is not installed.
Installing the specified package, then reinstalling the packages worked :)
Ah so the commands don't work exactly?
I think it depends on your environment and whether you have that package preinstalled or not. I think adding absl-py to the jax-triton install line should be sufficient, but I haven't time to validate that immediately.
I'll note that absl-py is not a dependency of JAX itself, but it is a dependency of JAX's tests, so it's likely that many JAX users don't have it installed while most JAX developers do.
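If you are unsure whether your environment already has it, a quick check (just a sketch):
try:
    import absl  # absl-py installs under the top-level `absl` package
    print("absl-py is installed")
except ImportError:
    print("absl-py is missing; install it before (re)installing jax-triton with --no-deps")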
This is great, thank you Sharad! I was able to run the quick start items on an A100. I notice that it's pretty slow to compile, is that normal? Also, I'm getting this warning:
2024-01-04 17:27:28.729972: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12020. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.
Is it okay to ignore or will it affect performance?
Worked for me as well, thanks!
@karan-dalal It's safe to ignore that warning. It is fixed (its verbosity was turned down) at head.
We are now publishing containers from https://github.com/NVIDIA/JAX-Toolbox that include the Pallas dependencies. You can find these as special tags of the jax container, for example the latest one is ghcr.io/nvidia/jax:latest-pallas; this is the same as ghcr.io/nvidia/jax:latest but with compatible versions of triton and jax-triton installed as well.
While we are still ironing out bugs here, it might be useful to use an older version such as https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax/168056269?tag=nightly-pallas-2023-12-16.
Hopefully this is an easy way to get started with Pallas on GPU!
Note that there has been a small reorganisation to how these containers are labelled.
The latest one right now is ghcr.io/nvidia/jax:pallas-2024-02-09, and they can be found under this link.
Hi everyone, the latest jaxlib version (0.4.25) no longer requires the jax_triton or triton packages to compile Pallas kernels on GPU.
Please give it a try and let us know if you run into any issues.
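As a quick check (a minimal sketch, assuming a GPU machine with the CUDA jaxlib wheel installed), the following should now work without triton or jax_triton in the environment:
import jax
import jaxlib
import jax.numpy as jnp
from jax.experimental import pallas as pl

print(jaxlib.__version__)  # expect >= 0.4.25
print(jax.devices())       # expect CUDA devices on a GPU machine

def copy_kernel(x_ref, o_ref):
  # Trivial kernel: copy the input block to the output block.
  o_ref[...] = x_ref[...]

print(pl.pallas_call(copy_kernel,
                     out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))(jnp.arange(8)))
# expected output: [0 1 2 3 4 5 6 7]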
Just wanted to commend the amazing effort Sergei put into this. He enabled a brand-new lowering path for Pallas that goes purely through C++, emitting the MLIR that Triton normally emits. It was a very tricky thing to get right!
@superbobry, I am trying to debug a performance degradation, so I need to run some older versions (JAX v0.4.22-v0.4.24). I've tried several combinations of triton and jax_triton but haven't been able to make them work for JAX v0.4.24. I can make it work for JAX 0.4.25, 0.4.23, and 0.4.21, just not v0.4.24.
Could you inform me which versions of triton and jax_triton are compatible with Jax v0.4.24? I am running the example found at https://jax.readthedocs.io/en/latest/pallas/quickstart.html, and I encountered the following error:
Traceback (most recent call last):
File "/root/test/test_pallas.py", line 18, in <module>
add_vectors(jnp.arange(8), jnp.arange(8))
File "/root/test/test_pallas.py", line 15, in add_vectors
return pl.pallas_call(add_vectors_kernel,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/pallas_call.py", line 532, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
I often find that the error reporting is not accurate and does not point to the actual source of the error. I tried this commit of jax_triton (https://github.com/jax-ml/jax-triton/commit/28ad4766271a181587e6e17e17de7f729c1a03b5) and the triton commit mentioned in the commit message. It reports the error above.
Unfortunately, I'm not sure. I tried finding a working combination for 0.4.24 myself some time ago, and couldn't get it to work.
Do you suspect that the performance degradation is due to changes in Pallas?
I'm trying to test out the splash attention kernel on GPU and getting the error "NotImplementedError: Scalar prefetch not supported in Triton lowering." My test works fine on a TPU v4, however. I installed the latest version of JAX for GPU.
Here's the test:
from jax.experimental.pallas.ops.tpu.splash_attention import make_splash_mha, CausalMask, MultiHeadMask, SegmentIds
import jax.numpy as jnp
import jax
splash = make_splash_mha(
    mask=MultiHeadMask([CausalMask((128, 128)) for _ in range(8)]),
    head_shards=1,
    q_seq_shards=1,
)
qs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
ks = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
vs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
segment_ids = SegmentIds(
    jnp.asarray(([0] * 96) + ([1] * 32), dtype=jnp.int32),
    jnp.asarray(([0] * 96) + ([1] * 32), dtype=jnp.int32),
)
output = splash(
    q=qs, k=ks, v=vs, segment_ids=segment_ids,
)
print(output.shape)
Hey @Sea-Snell, some features of Pallas TPU are indeed unsupported on GPU. Thus, the TPU-specific implementation of splash attention you are using does not currently work on GPU.