Proposed Dockerfile for Running JAX
Demand for JAX Dockerfile
There is some demand for a Dockerfile that runs jax / jaxlib (as opposed to the one currently in the repository that is used to build wheels for jax/jaxlib), see here, here, and probably here. I've separately pointed out that tensorflow-gpu images coincidentally work fine.
tf-gpu is somewhat of a moving target, and it's probably wise to de-couple any jax Dockerfile from what tf-gpu does. As part of a package I'm writing on top of jax, I've put together a Dockerfile with "only what jaxlib needs", using the nvidia cuda images as a base (specifically the Ubuntu 20.04 cudnn devel image). I'm interested in having this upstreamed - in exchange, I'm committed to maintaining the Dockerfile for at least some time. First, I present the Dockerfile. Then I'll lay out some of the design choices (that's of course flexible to whatever the jax team wants).
Pre-Requisites for use
Use of the image pre-supposes the user has installed the nvidia docker toolkit on their docker host. This has been tested on a RHEL8 docker host on consumer hardware.
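As a quick sanity check before building anything (this assumes the nvidia container toolkit is already installed; the CUDA image tag is just an example), you can verify that containers can actually see the GPU:

```shell
# Verify that the docker host can expose GPUs to containers.
# Any available nvidia/cuda tag works here; this one matches the base image below.
docker run --rm --gpus all nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 nvidia-smi
```

If `nvidia-smi` prints your GPUs, the toolkit is wired up correctly.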
Standalone Dockerfile
FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
# declare the image name
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
# declare what jaxlib tag to use
# if a CI/CD system is expected to pass in these arguments
# the dockerfile should be modified accordingly
JAXLIB_VERSION=0.1.62
# install python3-pip
RUN apt update && apt install python3-pip -y
# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
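For reference, here's a sketch of how the image above might be built and smoke-tested (the `jax-gpu` tag name is arbitrary, and the `jaxlib.version` attribute is assumed to exist in this jaxlib release):

```shell
# Build the image from the Dockerfile above (tag name is arbitrary)
docker build -t jax-gpu:0.1.62 .

# Smoke test: confirm jaxlib imports inside the container with GPUs exposed
docker run --rm --gpus all jax-gpu:0.1.62 \
    python3 -c "import jaxlib; print(jaxlib.version.__version__)"
```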
I have no idea how this can integrate into the CI/CD pipeline that the jax team has for the project (although I have looked at some of the files in build/), but I am happy to work with the team on integration.
Design Choices
The Dockerfile is intended to abstract away the trickiest part of getting jax+cuda working, which is jaxlib. As such it does not include jax itself (this simplifies the build matrix, since only every supported jaxlib version would need to be built, not every jaxlib-jax tuple). It is intended that users will pull this image as part of their own tooling, and append whatever jax version they need (either from source or pypi).
I've also chosen to install the resulting jaxlib wheel globally (as opposed, for example, to creating a new "jax" user). This is similar to how the tf-gpu images work.
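As a sketch of the intended downstream workflow (the base image tag and jax version here are hypothetical), a user would layer their own pinned jax version on top of the jaxlib-only image:

```shell
# Hypothetical downstream usage: derive an image that adds a pinned jax
# version on top of the jaxlib-only base image.
cat > Dockerfile <<'EOF'
FROM jax-gpu:0.1.62
RUN pip3 install jax==0.2.10
EOF
docker build -t my-project-jax .
```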
Legal Stuff
The use of the nvidia images is bound by their CUDA and CUDNN EULAs, and probably by Ubuntu licensing terms too. Not a lawyer, but whatever tf-gpu does (if any) to make their images compliant will probably work for the Dockerfile I've outlined above.
Thanks for getting this started! It sounds like this could help people get started with jax on GPU, although I'd like to understand the maintenance cost more. I personally have close to 0 docker experience, so I have a few basic questions:
- What exactly does upstreaming mean here? I.e. just checking in the Dockerfile somewhere, or also hosting the image, or ...?
- What kind of CI do you have in mind? We currently do all GPU testing using internal Google infrastructure, which isn't very docker-friendly. However, I'm looking into setting up GPU testing that runs in a non-Google environment, so we could potentially use the GPU image there.
cc @hawkinsp and @mattjj in case they have further thoughts.
Hi @skye, thanks for picking this up.
- Maintenance: Ideally, Google would keep the Dockerfile in the repo, and also publish the image under its name somewhere like Dockerhub (or GCR, I guess). That way, users can just docker pull google/jax-gpu:latest and be off to the races. New docker images can be built upon release (so, for example, the versioning would be google/jax-gpu:0.1.64 or similar).
  - The Dockerfile I've written will basically need to be dynamically updated with the jaxlib version to build with (passing it in as an environment variable or similar).
  - The image intentionally only installs jaxlib and not jax (because jaxlib is the thing that's tricky to get going, CUDA-acceleration-wise). This leaves flexibility for downstream projects (such as NumPyro or Optax) to layer on their choice of jax version (if they practice version pinning for their releases).
- What CI? I'm not employed or paid to shill any particular company, but I was looking at something like Drone CI for my own project. Basically, any time I push a new release tag, I'll get Drone CI to pull up an image (based on the work here), clone the repo into the Docker image, then run pytest <whatever> and report the results.
  - This has the upside of checking that the Docker image will work well for development work (see below). So any time someone asks "why doesn't jaxlib work with CUDA on my machine?", instead of debugging where the n-th *.so file is (been there, it's painful), the JAX team can say "have you tried running your code inside our Docker image?" - and this doesn't force everyone to use the same version of Ubuntu as their OS.
The package I'm building is actually layered on top of JAX and NumPyro (which is Pyro with a JAX backend, providing Bayesian inference - NUTS and so forth, the usual MCMC stuff). As such, I've proposed the same to them, which they have accepted. Right now, we're pulling the wheels from Google's storage API (NumPyro version-pins jax/jaxlib for releases, and I'm using pip-versions to parse the latest for a development image). If Google posted Docker images (or similar) online, then any library layering on top of JAX would just start their image from Google's (say google/jax-gpu:0.1.64 or google/jax-gpu:latest) and go from there.
Development Use-Case
As I pointed out in my NumPyro PR, this makes it easier for people to hack on and develop CUDA-accelerated JAX too. My personal workflow uses VSCode (I'm sure Google has different internal tooling) and their Remote Development extension. Basically, VSCode mounts my workspace as a volume inside the container, so I can run code "as if the Docker container was my OS-level venv" and that helps standardize a dev environment for future collaborators.
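Outside of VSCode, the same mount-the-workspace pattern can be approximated directly with docker (the paths and image tag here are illustrative):

```shell
# Mount the current project into the container and drop into a shell,
# so code runs "as if the Docker container were my OS-level venv".
docker run --rm -it --gpus all \
    -v "$PWD":/workspace -w /workspace \
    jax-gpu:0.1.62 bash
```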
My Maintenance Promise
I have some experience working with Docker, at least small-scale (I create project-level Dockerfiles for all the work I do to ensure reproducibility) - but I don't usually publish any of my images. I'm happy to help maintain the Dockerfile (oh no! the Docker image we built complains about *.so being in the wrong place! Where is it now?! We should update to Ubuntu 22.04 - oh wait, they changed the *.so location again...), and run some rudimentary tests on the consumer hardware that I have available (no multi-hour test suites, please).
Things That Could Go Wrong
The main pain point would be if the pre-built wheels no longer work with nvidia's cudnn images. Right now building the image is relatively painless: I'm just pip-installing the appropriate jaxlib wheel from the jax-provided Google link. If in the future the build environment for those wheels and the container diverge (I can't see a reason for that to happen, but you know how these things go...), then we would be in trouble. In an earlier iteration of the Dockerfile I tried building the wheel inside the Docker image (which would need to be done once per image version by the publisher, Google, and then be unnecessary for downstream users). It took a long time on my consumer-grade hardware, and didn't end up working (it built successfully, but wasn't able to find the correct combination of *.so locations for everything to fall into place).
Hopefully, if the build environment for the wheels and the Dockerfile proposed here diverge, we can catch it early, and I can modify the Dockerfile to match the build environment. (In the worst case, we would need to do as Tensorflow does and start from an Ubuntu base image, and specify how CUDA+CUDNN is installed. I'm really hoping to avoid that.)
Thanks for the detailed reply. I think we should give this a shot, with the provision that we don't guarantee perpetual support for the image(s) until we work out any kinks and validate that the cost/benefit ratio makes sense (I think it will!).
Your design choices and maintenance plan sound great.
CI testing is trickier. Drone looks pretty nice, and has the critical feature that we can run it on our own VMs, which is important for GPU testing. However, running our own CI servers is more than I'd like to bite off for now, especially if they're publicly accessible. The GPU CI I'm currently setting up might work here -- the plan is to build a GPU jaxlib from source and run our unit tests with it every night, so I can set up a separate job that only runs the unit tests using the docker image with a fresh jax checkout, either nightly or manually triggered for a release. The big downside here is that the tooling I'm using is Google-internal, so you wouldn't be able to trigger or view test results. I can alert you if anything goes wrong though :)
Putting it all together, what do you think of this plan? (Others feel free to chime in too!)
- You check in the Dockerfile, as well as instructions for how to update the image and upload to Dockerhub/GCR when a new jaxlib is released (ideally you'd take care of this, at least initially, but it's important that the jax team knows how to do it just in case).
- I'll set up (internal) CI testing for the image and monitor it. I'll reach out to you if I need help using the image (like I said, ~0 docker experience) or if the CI fails in a non-trivial way.
- You keep the docker file updated and fix any issues that we report based on the CI testing, at least for a while.
One risk of this plan is that I haven't actually set up the GPU testing yet, and while I'm hoping to get it working soon, it might be delayed indefinitely if it's more difficult than I'm anticipating. So we might wanna swap steps 1 and 2 in the above plan. I'm also open to more OS-friendly build/test options, but that'll be a longer discussion that we probably shouldn't block this docker image on.
Hi @skye, I've looked into the CI thing. If there's interest, the Google team can configure the repository to build and publish images to Dockerhub automatically on release/push. It's free and doesn't handle testing, but it does automate the image-building process. I think that's okay if we're not building from source (i.e., just pip-installing), but if we're building from source GitHub might complain (so this is good for a :version build upon release). The current Dockerfile assumes that Google-built binaries are available at the storage API site before a release is tagged in the public GitHub repo (otherwise the image build will fail, since it tries to pull a non-existent wheel).
I will over the next couple of weeks:
- Create a PR. I'll add a separate docker/ folder to the root of the repository, which will comprise:
  - A README covering what the images are, how to use them, and how to update the image / upload to Dockerhub (which is free).
  - version/, containing a Dockerfile like the one above. The user manually specifies the jaxlib version to use (or comments it out so that the version number can be passed in as an environment variable).
  - dev/, which uses pip-versions to install the latest released jaxlib. This should eventually be swapped out for a "build from current source" image (a 'nightly' or 'dev' branch). I'll also include the sketch of a Dockerfile for this swap (but it won't be working; I'll explain why).
- Build and distribute the resulting images under my personal account on Dockerhub, which will link back to the README in the jax repository.
  - For convenience, I'll have a jaxlib-gpu:with-jax tag that takes the latest version of CUDA jaxlib and adds on the latest version of jax. (I won't do every possible combination of jax+jaxlib.)
- Right now, builds will be manual (I don't have the bandwidth to set up a fork of the jax repo just to track releases and automate everything). I'll build an image for the last 4 releases of jaxlib, and build the jaxlib-gpu:with-jax plus dev tags maybe once a week or so with a rudimentary cron job.
- I'll draft a GitHub-Actions-based workflow that will build a new image and tag it accordingly upon release. Do you want this in a separate PR? (I'll modify the path so that it won't actually trigger, just so the team can see the config.)
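The weekly cron-driven builds mentioned above could be as simple as the following sketch (the Dockerhub account name and folder layout are illustrative, matching the proposed docker/ structure):

```shell
# Rudimentary weekly rebuild of the "latest + jax" and dev tags.
# Could be run from cron, e.g.: 0 3 * * 0 /path/to/rebuild.sh
set -e
docker build -t myuser/jaxlib-gpu:with-jax docker/version/
docker build -t myuser/jaxlib-gpu:dev docker/dev/
docker push myuser/jaxlib-gpu:with-jax
docker push myuser/jaxlib-gpu:dev
```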
What I'm hoping the jax team can provide is:
- Recommendations on how to test the images I build in a way that won't kill regular consumer hardware. Do I just follow the instructions here with a low n? (Btw, I think there's a small typo there where ip should be pip.)
- A Dockerhub ID, so I can add the relevant account as a collaborator in the Dockerhub repository for push access. If Google takes over distribution officially, I'll deactivate the repository on Dockerhub.
- Teamwork to figure out the build-from-source image together (it will be an adventure, for sure).
I'll of course cooperate with any testing/setup that the JAX team needs. (The use of CUDA in Docker requires the nvidia container toolkit, which only supports Debian, Ubuntu, RHEL, and CentOS - though it's possible to get it working on Fedora; I don't know what Google uses internally.) Happy to have a conversation either in issues or offline about this.
Your plan sounds great!
I'll draft a GitHub-Actions based workflow that will build a new image and tag it accordingly upon release. Do you want this in a separate PR? (I'll modify the path so that it won't actually trigger, just so the team can see the config)
This is awesome. Separate PR please (I generally prefer smaller PRs whenever reasonable).
Recommendations on how to test the images I build in a way that won't kill regular consumer hardware. Do I just follow the instructions here with a low n?
Just to clarify, you're running this on a GPU machine? In my experience, running with -n auto on GPU, or any -n at all, results in OOMs, so unfortunately running the GPU tests is pretty slow. Something like:
JAX_NUM_GENERATED_CASES=10 pytest tests
should work, but will still take a while. You can experiment with -n 2 or some other low number if you want. Note that there are currently some failures at head on GPU too. The following is probably sufficient for a smoke test to make sure it's not completely broken, if you're willing to depend on post-facto CI testing for full validation:
python tests/api_test.py
A Dockerhub ID so I can add the relevant account as a collaborator in the Dockerhub repository for push access.
I just created an account with ID skyewm. Do you think we should make an "official" JAX ID?
Teamwork to figure out the build-from-source image together (it will be an adventure, for sure).
We currently use one of TensorFlow's build images to build GPU jaxlibs (see the jaxlib build Dockerfile), which works well because most of jaxlib is a large TensorFlow dependency (since the XLA compiler is in TF's repo). I'd like to avoid maintaining our own build image that's compatible with TF's build when they already supply one, so would it be possible to run the TF docker image inside the dev image build? I recently refactored build_jaxlib_wheels.sh into helper functions to make it easier to build a single CUDA wheel, so hopefully we can just call build_cuda_wheels from the dev Dockerfile.
Hi @skye , thanks for getting back to me. RE: JAX docker ID, that can be for the future. For the moment, I just want a Google contact I can put as a collaborator in the Dockerhub repo for accountability purposes.
The point about tensorflow is worth discussing. I suspected as much (that JAX was using TensorFlow-GPU images; see my thread in the JAX discussions). Right now, my Dockerfile intentionally uses Nvidia's CUDNN image as a base. This is useful to me (and any downstream projects) because we can then use JAX's images as a base for our own without TensorFlow branding (and bloat). If we use TF images for the dev image build, but not for the release-tagged builds, then the dev image build environment might end up being different from the release images (but I'm okay to do this too).
If the JAX team wants images that are based on TensorFlow-GPU instead (whether for dev or for release, or both), I can modify the Dockerfiles accordingly (really, it's just changing the FROM nvidia/... line).
I'll leave it to the JAX team what image I should use as the base. (Right now, the TF-GPU images are compatible with nvidia's CUDNN images directly - they are, after all, both based on Ubuntu 18.04+.) You are absolutely correct that, hopefully, this means that in the future JAX won't need to distribute wheels for every CUDA/CUDNN version. They can instead say "if you want to use JAX+CUDA, either use the docker image we provide, or build from source" - this would probably be easier than the current system.
I wasn't suggesting basing the dev image off the TF image. Instead, I was proposing we use the regular cuda build script (via the build_cuda_wheels bash function I linked above) in order to build jaxlib in the dev image. That in turn runs the TF docker image to build the wheel (hopefully we can consider this an implementation detail), and then copies the built wheel out of the TF image back into the calling environment. So we wouldn't be building in the stripped-down image, but we'd still be testing in that image after the build is finished. Does this make sense?
Thanks for this - it was the only way I was able to get the GPU backend running locally. For anyone reading now, I had to modify the container as follows:
FROM nvidia/cuda:11.6.0-devel-ubuntu20.04
# declare the image name
ENV IMG_NAME=11.6.0-devel-ubuntu20.04 \
# declare what jaxlib tag to use
# if a CI/CD system is expected to pass in these arguments
# the dockerfile should be modified accordingly
JAXLIB_VERSION=0.3.0
# install python3-pip
RUN apt update && apt install python3-pip -y
# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
This seems to be working, but maybe I have made a mistake:
FROM nvidia/cuda:11.7.0-devel-ubuntu20.04
# install python3-pip
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+1 I have just spent hours trying to get jax pmap to work with 3090s. It is way too hard to find docker containers & jax versions that work - actually, I couldn't find one that works at all. Would love it if the jax project peeps could provide some dockerhub images.
@sudhakarsingh27 @nvcastet @nouiz for viz
@mjsML in particular, on a dual-3090 system:
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
base image nvidia/cuda:11.2.1-cudnn8-devel-ubuntu20.04
pip3 install install jax[cuda11_cudnn82]==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
XLA_PYTHON_CLIENT_PREALLOCATE=false
but nothing else special
idk - single-GPU works, but today pytorch's dockerhub images work way, way better, and yes, multi-GPU is the problem.
+1 We really need a docker image running jax, whether for deployment or development. It really helps ;)
Maybe this is the sustainable answer for JAX on multi-node, multi-gpu: JAX Container Early Access
NVIDIA official images are coming soon 😄 @pwais @lukaemon
@pwais One possible cause for your multi-GPU problem is that you need to provide SHM space to your docker container for the use of NCCL. Try setting shm-size to something like 2GB.
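Concretely, that would look something like the following (the image name here is a placeholder, and 2g is just an example size):

```shell
# Give the container more shared memory so NCCL collectives have room to work.
docker run --rm --gpus all --shm-size=2g my-jax-image \
    python3 -c "import jax; print(jax.devices())"
```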
The NVIDIA JAX early access (EA) container is now released, register to get it :)
This seems to be working but maybe I have made a mistake
FROM nvidia/cuda:11.7.0-devel-ubuntu20.04
# install python3-pip
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
nvidia/cuda:11.7.0-devel-ubuntu20.04 ships with cudnn 8.5, but "jax[cuda]" only works with cudnn 8.6+. The following works for me:
FROM nvidia/cuda:11.7.0-devel-ubuntu20.04
# install python3-pip
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Otherwise, use nvidia/cuda:11.8.0-devel-ubuntu20.04, which ships with cudnn 8.7:
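To check which cudnn a given base image actually ships before choosing a jaxlib wheel, something like this can help (the tag is just an example):

```shell
# Inspect the cudnn shared libraries present in a CUDA base image.
docker run --rm nvidia/cuda:11.8.0-devel-ubuntu20.04 \
    bash -c "ldconfig -p | grep -i cudnn"
```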
FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
# install python3-pip
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Which version of the nvidia/cuda base image (FROM nvidia/cuda:...) can we use with jax today?
I guess nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 should work if you update the JAX/JAXlib version installed.
None of the Dockerfile configurations here worked for me, and I ended up using the NVIDIA Jax containers:
docker pull ghcr.io/nvidia/jax:nightly-2023-09-09
JAX Toolbox includes the nightly container referenced above along with open Dockerfiles.