
JAX integration

Open · alxmrs opened this issue 2 years ago · 10 comments

Can the core array API ops of Cubed be implemented in JAX, such that everything easily compiles to accelerators? Could this solve the common pain point of running out of GPU memory? How would other constraints (GPU bandwidth limits) be handled? What is the ideal distributed runtime environment to make the most of this? Could spot GPU instances be used (serverless accelerators)?

alxmrs avatar Sep 10 '23 05:09 alxmrs

Thanks @alxmrs for opening this issue. I'm not familiar enough with JAX or GPUs to answer these questions, but I'd be happy to support or discuss an initiative in this direction. Is there a small piece of work that you have in mind that could be used to explore this?

tomwhite avatar Sep 11 '23 09:09 tomwhite

The JAX docs provide a few good toy examples that could be used to validate this idea.

Check out this tutorial: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

This tutorial on distributing computation on a pool of TPUs (specifically, the neural net section) may be of interest, too:

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#examples-neural-networks

A high-level goal for JAX + Cubed may be to make managing GPU memory effortless:

  • https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#common-causes-of-oom-failures
  • E.g.: could we do an FFT on a 100 GiB (TiB? PiB??) array? https://github.com/google/jax/discussions/13842

alxmrs avatar Sep 19 '23 18:09 alxmrs

Thanks for the pointers @alxmrs.

Thinking about how Cubed might hook into this, the main idea in Cubed is that every array is backed by Zarr, and an operation maps one Zarr array to another, by working in parallel on chunks (see https://tom-e-white.com/cubed/cubed.slides.html#/1).

Every operation is ultimately expressible as a blockwise operation (or a rechunk operation, but let's ignore that here), which ends up in the apply_blockwise function:

https://github.com/tomwhite/cubed/blob/0d13e4f2b12c22d1c41b9f4ea693266b21d808d0/cubed/primitive/blockwise.py#L53-L75

The key parts, sketched below, are:

  1. Reading each arg from Zarr into (CPU) memory (line 67),
  2. Invoking the function on the args (line 70), and
  3. Writing the result from (CPU) memory back to Zarr (line 73 or 75)
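
To make that concrete, here is a condensed, self-contained illustration of the three steps, using an in-memory Zarr store (the real apply_blockwise is more general; this is just a paraphrase, not the actual source):

```python
# Illustrative version of the read -> compute -> write cycle above.
import numpy as np
import zarr

source = zarr.array(np.arange(16.0).reshape(4, 4), chunks=(2, 2))
target = zarr.zeros((4, 4), chunks=(2, 2), dtype=source.dtype)

def apply_blockwise(region, func):
    arg = func_input = source[region]  # 1. read the arg chunk from Zarr into (CPU) memory
    result = func(arg)                 # 2. invoke the function on the in-memory chunk
    target[region] = result            # 3. write the result chunk back to Zarr

apply_blockwise(np.s_[0:2, 0:2], lambda x: x + 1)
```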

To change this to use JAX, we'd have to (1) read from Zarr into JAX arrays, (2) invoke the relevant JAX function on the arrays, and (3) write the resulting JAX array back to Zarr.

In fact, there might not be anything to do for 2., since you could call cubed.map_blocks with a JAX function, as sketched below.
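
For example, something like this should already work (a minimal, untested sketch; the Spec parameters and shapes are illustrative):

```python
# Sketch: cubed.map_blocks with a JAX function. Steps 1 and 3 (the Zarr
# reads/writes) still go through host memory; only step 2 runs in JAX.
import cubed
import cubed.array_api as xp
import jax.numpy as jnp
import numpy as np

spec = cubed.Spec(work_dir="/tmp/cubed", allowed_mem="2GB")  # illustrative
a = xp.ones((4096, 4096), chunks=(1024, 1024), spec=spec)

def jax_square(block):
    # The block arrives as a NumPy array read from Zarr; move it to the
    # accelerator, compute, then convert back for the Zarr write.
    return np.asarray(jnp.square(jnp.asarray(block)))

b = cubed.map_blocks(jax_square, a, dtype=a.dtype)
b.compute()
```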

This might be enough for the FFT example, although I'm a bit hazy on whether any post-processing is needed on the chunked (sharded) output.

A final thought. Is KvikIO, which does direct Zarr-GPU IO, related to this?

tomwhite avatar Sep 20 '23 09:09 tomwhite

> read from Zarr into JAX arrays

> A final thought. Is KvikIO, which does direct Zarr-GPU IO, related to this?

Reading the xarray blog post on this that @dcherian and @weiji14 wrote, it seems they used a Zarr store provided by kvikio. I expect Cubed could use this to load data from Zarr directly to the GPU in the form of a cupy array, which would be cool. (You could probably even use the xarray backend they wrote alongside cubed-xarray to achieve this.)
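
If that works, the user-facing code might look something like this (speculative: engine="kvikio" comes from cupy-xarray/kvikio, chunked_array_type="cubed" from cubed-xarray, and I haven't tried combining them):

```python
# Hypothetical: open a Zarr store via the kvikio backend so variables are
# loaded directly into GPU memory as cupy arrays, chunked with cubed.
import xarray as xr

ds = xr.open_dataset(
    "/data/store.zarr",             # illustrative path
    engine="kvikio",                # from cupy-xarray / kvikio
    chunked_array_type="cubed",     # from cubed-xarray
    chunks={},
)
```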

I tried to find whether there was anything similar for JAX, but didn't see anything (only this discussion: https://github.com/google/jax/discussions/17534). Writing from JAX to tensorstore was done for checkpointing language models (https://blog.research.google/2022/09/tensorstore-for-high-performance.html?m=1), but one would have thought that making tensorstore return JAX arrays directly would have been tried by now...

TomNicholas avatar Feb 09 '24 19:02 TomNicholas

KvikIO loads data into cupy, but it should technically be possible to zero-copy cupy arrays to JAX, PyTorch, or any array library that implements conversion via dlpack or the __cuda_array_interface__ protocol. It looks like JAX supports this already (https://github.com/google/jax/issues/1100)? But I haven't tried this end to end yet. There's also NVIDIA DALI, which seems to work with JAX (https://docs.nvidia.com/deeplearning/dali/archives/dali_1_32_0/user-guide/docs/plugins/jax_tutorials.html#jax), but the interface is a little less convenient since you need to set up a pipeline. Generally, the integration between the RAPIDS AI libraries (which build on cupy) is a bit better on the PyTorch side, via the RAPIDS Memory Manager (https://github.com/rapidsai/rmm/blob/branch-24.04/README.md#using-rmm-with-third-party-libraries).
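
For reference, the zero-copy handoff itself should only be a couple of lines (a sketch; the exact API depends on the JAX and cupy versions):

```python
# Zero-copy a cupy array (e.g. a chunk read by kvikio) into JAX via DLPack;
# the data stays on the GPU with no host round-trip.
import cupy as cp
import jax.dlpack

gpu_chunk = cp.arange(10, dtype=cp.float32)    # stands in for a kvikio-loaded chunk
jax_chunk = jax.dlpack.from_dlpack(gpu_chunk)  # newer JAX accepts __dlpack__ objects
# On older JAX versions: jax.dlpack.from_dlpack(gpu_chunk.toDlpack())
```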

weiji14 avatar Feb 11 '24 22:02 weiji14

Hey @tomwhite, I have a question for you: to run JAX array operations on accelerators (M1+ chips, GPUs, TPUs, etc.), someone needs to call jax.jit: https://jax.readthedocs.io/en/latest/quickstart.html#just-in-time-compilation-with-jax-jit

Where is a good place to make this kind of call within Cubed? Is this something that should be handled by an Executor (this seems not so ideal)?

(Here are some more in-depth docs on JAX's jit: https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation.)

A related concern that I haven't properly figured out yet: how should this intersect with devices and sharding? https://jax.readthedocs.io/en/latest/sharded-computation.html

alxmrs avatar Jul 23 '24 13:07 alxmrs

> Where is a good place to make this kind of call within Cubed? Is this something that should be handled by an Executor (this seems not so ideal)?

Possibly as part of DAG finalization, after (Cubed) optimization has been run. Note that the function you want to jit will be the function in BlockwiseSpec: https://github.com/cubed-dev/cubed/blob/59c593d32ef09466f207e2011ca7a73057d21c47/cubed/primitive/blockwise.py#L47-L73
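
Very roughly, something like this (hypothetical: BlockwiseSpec is real, but the finalization hook and attribute access here are made up for illustration):

```python
import jax

def jit_blockwise_spec(spec):
    # Hypothetical finalization step: swap the chunk function in a
    # BlockwiseSpec for a jit-compiled version. jax.jit traces on the
    # first chunk and reuses the compiled executable for later chunks
    # with the same shapes and dtypes.
    spec.function = jax.jit(spec.function)
    return spec
```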

What's the simplest possible example to start with?

tomwhite avatar Jul 23 '24 15:07 tomwhite

Thanks for your suggestion, Tom. I've prototyped something here: https://github.com/alxmrs/cubed/pull/1

For now, it looks like I need to work on landing the M1 PR before I can take this any further.

alxmrs avatar Jul 23 '24 22:07 alxmrs

Nice!

tomwhite avatar Jul 24 '24 08:07 tomwhite

I think compiling the (Cubed-optimized) blockwise functions using AOT compilation (as you mentioned in https://github.com/cubed-dev/cubed/issues/490#issuecomment-2247599475), and then exporting them so they can run in other processes (https://jax.readthedocs.io/en/latest/export/export.html), may be the way to go. Perhaps this is worth trying on CPU first.
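
Roughly, following the linked export docs (a sketch; the chunk function, shape, and dtype are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def chunk_fn(x):
    return jnp.fft.fft(x)  # stand-in for a Cubed blockwise function

# Ahead-of-time: compile and export for a fixed chunk shape/dtype...
exported = export.export(jax.jit(chunk_fn))(
    jax.ShapeDtypeStruct((1024, 1024), np.complex64)
)
blob = exported.serialize()

# ...then, in a worker process, deserialize and call on a real chunk.
rehydrated = export.deserialize(blob)
result = rehydrated.call(jnp.zeros((1024, 1024), jnp.complex64))
```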

tomwhite avatar Jul 30 '24 15:07 tomwhite