Make Pallas/GPU easier to install

Open hawkinsp opened this issue 2 years ago • 20 comments

Currently it's very difficult to install Pallas and jax_triton, since you have to get compatible versions of everything, and it's very finicky to work out which they are. We should make this easier!

hawkinsp avatar Nov 20 '23 13:11 hawkinsp

I thought it would be useful to post here that the following steps got me a working installation on top of the ghcr.io/nvidia/jax:nightly-2023-12-08 container (from https://github.com/NVIDIA/JAX-Toolbox):

# pip install --no-deps 'jax-triton@git+https://github.com/jax-ml/jax-triton.git@test_588045313' # e4fd5cb21f40c3991a204479c3a1a0e3f0194e91
# cd /opt
# git clone -b llvm-head https://github.com/openai/triton.git # ca78acaf1a6cf68e2af8a68762ec852534ff0610
# cd triton/
# pip install -e python # quite slow
# cd /opt/jax
# git checkout test_588045313 # 56c46f36df6cfaa40253e839203f950481ce97cd
# build-jax.sh # seems to rebuild more than I expected, might be improvable

the test_588045313 branches refer to these PRs:

  • https://github.com/google/jax/pull/18838
  • https://github.com/jax-ml/jax-triton/pull/241

and IIUC we would prefer to build Triton from https://github.com/openxla/triton. Hopefully once those PRs land, the recipe can be simplified a little.

Big ➕ to the sentiment of this issue: this should be easier!

olupton avatar Dec 11 '23 14:12 olupton

Yes please!

mehdiataei avatar Dec 13 '23 07:12 mehdiataei

+1! Started to play with https://github.com/google/maxtext but immediately ran into Pallas/jax-triton errors, and after some effort I haven't been able to get a JAX build with Pallas/jax-triton working.

CoderPat avatar Dec 15 '23 17:12 CoderPat

+1 as well, I spent quite a while installing different permutations of versions of things but I couldn't find one that worked.

davisyoshida avatar Dec 31 '23 07:12 davisyoshida

Another big +1 here, would really love to use Pallas but I cannot find a correct set of commands to install it correctly.

vvvm23 avatar Jan 03 '24 16:01 vvvm23

Hi everyone, I was able to get a set of mutually compatible version pins. I hope this unblocks folks temporarily while we push on the longer-term solution (bundling Triton with jaxlib GPU).

Versions:

triton-nightly==2.1.0.post20231216005823
jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
jaxlib==0.4.24.dev20240103
jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c

Installation commands:

$ pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
$ pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$ pip install -IU --pre jaxlib[cuda12_pip]==0.4.24.dev20240103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
$ pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'

I verified this works with a simple kernel in Colab but haven't yet run more extensive tests. Please let me know if these work for you.
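
For reference, a minimal smoke test along the lines of the add_vectors example from the Pallas quickstart (a sketch, not necessarily the exact kernel I ran):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Read the full input blocks and write their sum to the output ref.
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

# If this prints [ 0  2  4  6  8 10 12 14], the GPU lowering path works.
print(add_vectors(jnp.arange(8), jnp.arange(8)))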

sharadmv avatar Jan 04 '24 20:01 sharadmv

Thanks @sharadmv! Appreciate the quick workaround.

Here is a copy-pasteable version (note: you need to be using Python 3.9 or 3.10):

pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823
pip install -IU --pre jax==0.4.24.dev20240104 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install -IU --pre "jaxlib[cuda12_pip]==0.4.24.dev20240103" -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'

EDIT: Though for me, this still didn't work locally, perhaps due to some interference with my local CUDA?

vvvm23 avatar Jan 04 '24 21:01 vvvm23

What's your error?

sharadmv avatar Jan 04 '24 21:01 sharadmv

This happens when running the Hello World example in the Quickstart, in a fresh environment set up with the block above, with CUDA 12.3.1-2 installed locally.

2024-01-04 21:03:19.992205: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12030. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:196, in _run_module_as_main()
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File ~/.pyenv/versions/3.10.12/lib/python3.10/runpy.py:86, in _run_code()
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/traitlets/config/application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:701, in start()
    700 try:
--> 701     self.io_loop.start()
    702 except KeyboardInterrupt:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/tornado/platform/asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:603, in run_forever()
    602 while True:
--> 603     self._run_once()
    604     if self._stopping:

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
   1908     else:
-> 1909         handle._run()
   1910 handle = None

File ~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
    533 try:
--> 534     await self.process_one()
    535 except Exception:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:523, in process_one()
    522         return
--> 523 await dispatch(*args)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
    428     if inspect.isawaitable(result):
--> 429         await result
    430 except Exception:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:767, in execute_request()
    766 if inspect.isawaitable(reply_content):
--> 767     reply_content = await reply_content
    769 # Flush output before sending the reply.

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:429, in do_execute()
    428 if accepts_params["cell_id"]:
--> 429     res = shell.run_cell(
    430         code,
    431         store_history=store_history,
    432         silent=silent,
    433         cell_id=cell_id,
    434     )
    435 else:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
   3050 try:
-> 3051     result = self._run_cell(
   3052         raw_cell, store_history, silent, shell_futures, cell_id
   3053     )
   3054 finally:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
   3105 try:
-> 3106     result = runner(coro)
   3107 except BaseException as e:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
   3308 interactivity = "none" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3312        interactivity=interactivity, compiler=compiler, result=result)
   3314 self.last_execution_succeeded = not has_raised

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
   3492     asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
   3494     return True

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
   3552     else:
-> 3553         exec(code_obj, self.user_global_ns, self.user_ns)
   3554 finally:
   3555     # Reset our crash handler in place

Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))

Cell In[3], line 7, in add_vectors()
      5 @jax.jit
      6 def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
----> 7   return pl.pallas_call(add_vectors_kernel,
      8                         out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:456, in wrapped()
    455 which_linear = (False,) * len(flat_args)
--> 456 out_flat = pallas_call_p.bind(
    457     *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
    458     in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
    459                     for a in flat_args),
    460     out_shapes=tuple(flat_out_shapes), debug=debug,
    461     interpret=interpret,
    462     grid_mapping=grid_mapping,
    463     input_output_aliases=tuple(input_output_aliases.items()),
    464     **compiler_params)
    465 out = tree_util.tree_unflatten(out_tree, out_flat)

JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 add_vectors(jnp.arange(8), jnp.arange(8))

    [... skipping hidden 23 frame]

File ~/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:415, in _pallas_call_default_lowering(ctx, interpret, *in_nodes, **params)
    411   raise ValueError("Only interpret mode is supported on CPU backend.")
    412 # If we are actually using a specific backend (GPU or TPU), we should have
    413 # already registered backend-specific lowerings. If we get this far, it means
    414 # those backends aren't present.
--> 415 raise ValueError(
    416     f"Cannot lower pallas_call on platform: {platform}. "
    417     "To use Pallas on GPU, please install Triton and JAX-Triton. "
    418     "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")

ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

vvvm23 avatar Jan 04 '24 21:01 vvvm23

Can you try running this Python snippet?

import triton
from triton.compiler import code_generator as code_gen
from triton.compiler import compiler as tc
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb
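# If all of the above import cleanly, the installed triton is compatible;
# an ImportError here usually points to a triton / jax-triton version mismatch.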

sharadmv avatar Jan 04 '24 21:01 sharadmv

Sorry, I missed a warning during the install:

jax-triton 0.1.4 requires absl-py>=1.4.0, which is not installed.

Installing the missing package and then re-running the install commands worked :)

vvvm23 avatar Jan 04 '24 21:01 vvvm23

Ah, so the commands don't work exactly as given?

sharadmv avatar Jan 04 '24 21:01 sharadmv

I think it depends on your environment and whether you have that package preinstalled. Adding absl-py to the jax-triton install line should be sufficient, but I haven't had time to validate that yet.
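
Something like this, assuming absl-py was the only missing dependency (a sketch, not yet validated):

pip install absl-py
pip install --no-deps 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git@7778c47c0a27c0988c914dce640dec61e44bbe8c'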

vvvm23 avatar Jan 04 '24 22:01 vvvm23

I'll note that absl-py is not a dependency of JAX itself, but it is a dependency of JAX's tests, so it's likely that many JAX users don't have it installed while most JAX developers do.

hawkinsp avatar Jan 04 '24 22:01 hawkinsp

This is great, thank you Sharad! I was able to run the quickstart items on an A100. I notice that it's pretty slow to compile; is that normal? Also, I'm getting this warning:

2024-01-04 17:27:28.729972: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command buffer support for CUBLAS as it's not supported with gpu toolkit version 12020 and driver version 12020. This might negatively impact peformance. To enable CUBLAS support in command buffers use cuda-compat package: https://docs.nvidia.com/deploy/cuda-compatibility/.

Is it okay to ignore or will it affect performance?

karan-dalal avatar Jan 05 '24 01:01 karan-dalal

Worked for me as well, thanks!

davisyoshida avatar Jan 05 '24 02:01 davisyoshida

@karan-dalal It's safe to ignore that warning. It is fixed (its verbosity was turned down) at head.

hawkinsp avatar Jan 05 '24 22:01 hawkinsp

We are now publishing containers from https://github.com/NVIDIA/JAX-Toolbox that include the Pallas dependencies. You can find these as special tags of the jax container; for example, the latest one is ghcr.io/nvidia/jax:latest-pallas, which is the same as ghcr.io/nvidia/jax:latest but with compatible versions of triton and jax-triton installed as well.

While we are still ironing out bugs here, it might be useful to use an older version such as https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax/168056269?tag=nightly-pallas-2023-12-16.

Hopefully this is an easy way to get started with Pallas on GPU!
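
For example, a typical way to start the latest container (assuming Docker with the NVIDIA Container Toolkit configured for GPU passthrough):

docker run --gpus all -it ghcr.io/nvidia/jax:latest-pallas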

olupton avatar Jan 10 '24 15:01 olupton

Following up on my previous comment: note that there has been a small reorganisation to how these containers are labelled. The latest one right now is ghcr.io/nvidia/jax:pallas-2024-02-09, and they can be found under this link.

olupton avatar Feb 09 '24 16:02 olupton

Hi everyone, the latest jaxlib version (0.4.25) no longer requires the jax_triton or triton packages to compile Pallas kernels on GPU.

Please give it a try and let us know if you run into any issues.
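
For example, something like this should pull in everything needed (one possible invocation; adjust the CUDA extra for your setup):

pip install -U "jax[cuda12_pip]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html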

superbobry avatar Feb 26 '24 17:02 superbobry

Just wanted to commend the amazing effort Sergei put into this. He enabled a brand-new lowering path for Pallas that goes purely through C++ by emitting the MLIR that Triton normally emits. It was a very tricky thing to get right!

sharadmv avatar Feb 29 '24 02:02 sharadmv

@superbobry, I am trying to debug some performance degradation issues, so I need to run some older versions (JAX v0.4.22-v0.4.24). I've tried several combinations of triton and jax_triton but haven't been able to make them work for JAX v0.4.24. I can make it work for JAX 0.4.25, 0.4.23, and 0.4.21, just not v0.4.24.

Could you tell me which versions of triton and jax_triton are compatible with JAX v0.4.24? When running the example at https://jax.readthedocs.io/en/latest/pallas/quickstart.html, I encountered the following error:

Traceback (most recent call last):
  File "/root/test/test_pallas.py", line 18, in <module>
    add_vectors(jnp.arange(8), jnp.arange(8))
  File "/root/test/test_pallas.py", line 15, in add_vectors
    return pl.pallas_call(add_vectors_kernel,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/pallas_call.py", line 532, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Cannot lower pallas_call on platform: cuda. To use Pallas on GPU, please install Triton and JAX-Triton. To use Pallas on TPU, please install Jaxlib TPU and libtpu.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

I've often found that the error reporting is inaccurate and doesn't pinpoint the actual source of the error. I tried this commit of jax_triton (https://github.com/jax-ml/jax-triton/commit/28ad4766271a181587e6e17e17de7f729c1a03b5) together with the triton commit mentioned in its commit message; it reports the error above.

merrymercy avatar Mar 28 '24 22:03 merrymercy

Unfortunately, I'm not sure. I tried finding a working combination for 0.4.24 myself some time ago, and couldn't get it to work.

Do you suspect that the performance degradation is due to changes in Pallas?

superbobry avatar Mar 28 '24 23:03 superbobry

I'm trying to test the splash attention kernel on GPU and am getting the error "NotImplementedError: Scalar prefetch not supported in Triton lowering." The same test works fine on TPU v4, however. I installed the latest version of JAX for GPU.

Here's the test:

from jax.experimental.pallas.ops.tpu.splash_attention import make_splash_mha, CausalMask, MultiHeadMask, SegmentIds
import jax.numpy as jnp
import jax

splash = make_splash_mha(
    mask=MultiHeadMask([CausalMask((128, 128)) for _ in range(8)]),
    head_shards=1,
    q_seq_shards=1,
)

qs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
ks = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)
vs = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 256), dtype=jnp.float32)

segment_ids = SegmentIds(
    jnp.asarray(([0]*96)+([1]*32), dtype=jnp.int32),
    jnp.asarray(([0]*96)+([1]*32), dtype=jnp.int32),
)

output = splash(
    q=qs, k=ks, v=vs, segment_ids=segment_ids,
)

print(output.shape)

Sea-Snell avatar May 18 '24 17:05 Sea-Snell

Hey @Sea-Snell, some features of Pallas TPU are indeed unsupported on GPU, so the TPU-specific implementation of splash attention you are using does not currently work there.

superbobry avatar May 18 '24 20:05 superbobry