
Proposed Dockerfile for Running JAX

Open AndrewCSQ opened this issue 3 years ago • 13 comments

Demand for JAX Dockerfile

There is some demand for a Dockerfile that runs jax / jaxlib (as opposed to the one currently in the repository that is used to build wheels for jax/jaxlib), see here, here, and probably here. I've separately pointed out that tensorflow-gpu images coincidentally work fine.

tf-gpu is somewhat of a moving target, and it's probably wise to de-couple any jax Dockerfile from what tf-gpu does. As part of a package I'm writing on top of jax, I've put together a Dockerfile with "only what jaxlib needs", using the nvidia cuda images as a base (specifically the Ubuntu 20.04 cudnn devel image). I'm interested in having this upstreamed - in exchange, I'm committed to maintaining the Dockerfile for at least some time. First, I present the Dockerfile. Then I'll lay out some of the design choices (that's of course flexible to whatever the jax team wants).

Pre-Requisites for use

Use of the image presupposes that the user has installed the nvidia docker toolkit on their docker host. This has been tested on a RHEL8 docker host on consumer hardware.
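
As a quick sanity check (a sketch only; the exact tag doesn't matter much), nvidia-smi should work from inside a plain CUDA container once the toolkit is set up:

# if the GPUs show up here, the nvidia docker toolkit is wired up correctly
docker run --rm --gpus all nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 nvidia-smi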

Standalone Dockerfile

FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04

# declare the image name
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
    # declare what jaxlib tag to use
    # if a CI/CD system is expected to pass in these arguments
    # the dockerfile should be modified accordingly
    JAXLIB_VERSION=0.1.62

# install python3-pip
RUN apt update && apt install python3-pip -y

# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
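
For reference, this is roughly how I've been exercising it locally (the image name is just a placeholder I use):

# build the image, then confirm the GPUs are visible inside it before layering jax on top
docker build -t jaxlib-cuda:0.1.62 .
docker run --rm --gpus all jaxlib-cuda:0.1.62 nvidia-smi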

I have no idea how this can integrate into the CI/CD pipeline that the jax team has for the project (although I have looked at some of the files in build/), but I'm happy to work with the team on integration.

Design Choices

The Dockerfile is intended to abstract away the trickiest part of getting jax+cuda working, which is jaxlib. As such it does not include jax itself (this simplifies the build matrix, since only every supported jaxlib version would need to be built, not every jaxlib-jax tuple). It is intended that users will pull this image as part of their own tooling, and append whatever jax version they need (either from source or pypi).
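
For illustration, a downstream image could be as simple as the following sketch (the base image name is a placeholder for wherever the jaxlib image ends up being published):

FROM jaxlib-cuda:0.1.62

# downstream projects would pin whatever jax version they support; left unpinned here only for illustration
RUN pip3 install jax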

I've also chosen to install the resulting jaxlib wheel globally (as opposed, for example, to creating a new "jax" user). This is similar to how the tf-gpu images work.

Legal Stuff

The use of the nvidia images is bound by their CUDA and CUDNN EULAs, and probably by Ubuntu licensing terms too. I'm not a lawyer, but whatever tf-gpu does (if anything) to make their images compliant will probably work for the Dockerfile I've outlined above.

AndrewCSQ avatar Apr 04 '21 15:04 AndrewCSQ

Thanks for getting this started! It sounds like this could help people get started with jax on GPU, although I'd like to understand the maintenance cost more. I personally have close to 0 docker experience, so I have a few basic questions:

  1. What exactly does upstreaming mean here? I.e. just checking in the Dockerfile somewhere, or also hosting the image, or ...?

  2. What kind of CI do you have in mind? We currently do all GPU testing using internal Google infrastructure, which isn't very docker-friendly. However, I'm looking into setting up GPU testing that runs in a non-Google environment, so we could potentially use the GPU image there.

cc @hawkinsp and @mattjj in case they have further thoughts.

skye avatar Apr 08 '21 23:04 skye

Hi @skye, thanks for picking this up.

  1. Maintenance: Ideally, Google would have the Dockerfile in the repo, and also publish the image under its name somewhere like Dockerhub (or GCR, I guess). That way, users can just docker pull google/jax-gpu:latest and be off to the races. New docker images can be built upon release (so, for example, the versioning would be google/jax-gpu:0.1.64 or similar).
    • The Dockerfile I've written will basically need to be dynamically updated with the JAXLIB version to build against (passing it in as an environment variable, build argument, or similar - see the sketch after this list).
    • The image intentionally only installs jaxlib and not jax (because jaxlib is the thing that's tricky to get going CUDA-acceleration wise). This leaves flexibility for downstream projects (such as NumPyro or Optax) to layer on their choice of jax version (if they practice version pinning for their releases).
  2. What CI? I'm not employed or paid to shill any particular company, but I was looking at something like Drone CI for my own project. Basically, any time I push a new release tag, I'll have Drone CI pull up an image (based on the work here), clone the repo into the Docker image, run pytest <whatever>, and report the results.
    • This has the upside of checking that the Docker image will work well for development work (see below). So anytime anyone has a "why doesn't jaxlib work with CUDA on my machine" issue, instead of debugging where the n-th *.so file is (been there, it's painful), the JAX team can say "have you tried running your code inside our Docker image?" <- this doesn't force everyone to use the same version of Ubuntu as their OS.
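
To make point 1 concrete, here's a minimal sketch of the build-arg variant (versions shown are only examples):

FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04

# a CI system can override this at build time; defaults keep local builds working
ARG JAXLIB_VERSION=0.1.62

RUN apt update && apt install python3-pip -y

RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html

which would be built with something like:

docker build --build-arg JAXLIB_VERSION=0.1.64 -t jaxlib-cuda:0.1.64 .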

The package I'm building is actually layered on top of JAX and NumPyro (which is Pyro with a JAX backend, providing Bayesian inference - NUTS and so forth, the usual MCMC stuff). As such, I've proposed the same to them, which they have accepted. Right now, we're pulling from Google's storage API for the wheels (NumPyro version-pins jax/jaxlib for release, and I'm using pip-versions to parse the latest for a development image). If Google posted Docker images (or similar) online, then any library layering on top of JAX would just base its image on Google's (say google/jax-gpu:0.1.64 or google/jax-gpu:latest) and go from there.

Development Use-Case

As I pointed out in my NumPyro PR, this makes it easier for people to hack on and develop CUDA-accelerated JAX too. My personal workflow uses VSCode (I'm sure Google has different internal tooling) and their Remote Development extension. Basically, VSCode mounts my workspace as a volume inside the container, so I can run code "as if the Docker container was my OS-level venv" and that helps standardize a dev environment for future collaborators.
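
Outside of VSCode, the same idea is roughly this (image name and paths are placeholders):

# mount the current project into the container and work inside it as if it were a venv
docker run --rm -it --gpus all -v "$PWD":/workspace -w /workspace jaxlib-cuda:0.1.62 bash
# then, inside the container, install jax plus the project in the mounted workspace and run as usual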

My Maintenance Promise

I have some experience working with Docker at least small-scale (I create project-level Dockerfiles for all the work I do to ensure reproducibility) - but I don't usually publish any of my images. I'm happy to help maintain the Dockerfile (oh no! the Docker image we built complains about *.so being in the wrong place! Where is it now?! We should update to Ubuntu 22.04 - oh wait they changed the *.so location again...), and run some rudimentary tests on the consumer hardware that I have available (no multi-hour test suites please).

Things That Could Go Wrong

The main pain point would be if the pre-built wheels no longer work with nvidia's cudnn images. Right now building the image is relatively painless. I'm just pip-installing the appropriate jaxlib wheel from the jax-provided Google link. If in the future the build environment for those wheels and the container diverge (can't see a reason for that to happen, but you know how these things go...), then we would be in trouble. In an earlier iteration of the Dockerfile I tried building the wheel inside the Docker image (which would need to be done once per image version by the publisher, Google, and then be unnecessary for downstream users). It took a long time on my consumer grade hardware, and didn't end up working (it built successfully, but wasn't able to find the correct combination of *.so locations for everything to fall into place).

Hopefully if the build environment for the wheels and the Dockerfile proposed here diverge, we can catch it early, and I can modify the Dockerfile to match the build environment. (In the worst case, we would need to do as TensorFlow does and start from an Ubuntu base image, and specify how CUDA+CUDNN is installed. I'm really hoping to avoid that.)

AndrewCSQ avatar Apr 08 '21 23:04 AndrewCSQ

Thanks for the detailed reply. I think we should give this a shot, with the provision that we don't guarantee perpetual support for the image(s) until we work out any kinks and validate that the cost/benefit ratio makes sense (I think it will!).

Your design choices and maintenance plan sound great.

CI testing is trickier. Drone looks pretty nice, and has the critical feature that we can run it on our own VMs, which is important for GPU testing. However, running our own CI servers is more than I'd like to bite off for now, especially if they're publicly accessible. The GPU CI I'm currently setting up might work here -- the plan is to build a GPU jaxlib from source and run our unit tests with it every night, so I can set up a separate job that only runs the unit tests using the docker image with a fresh jax checkout, either nightly or manually triggered for a release. The big downside here is that the tooling I'm using is Google-internal, so you wouldn't be able to trigger or view test results. I can alert you if anything goes wrong though :)

Putting it all together, what do you think of this plan? (Others feel free to chime in too!)

  1. You check in the Dockerfile, as well as instructions for how to update the image and upload it to Dockerhub/GCR when a new jaxlib is released (ideally you'd take care of this, at least initially, but it's important that the jax team knows how to do it just in case).
  2. I'll set up (internal) CI testing for the image and monitor it. I'll reach out to you if I need help using the image (like I said, ~0 docker experience) or if the CI fails in a non-trivial way.
  3. You keep the Dockerfile updated and fix any issues that we report based on the CI testing, at least for a while.

One risk of this plan is that I haven't actually set up the GPU testing yet, and while I'm hoping to get it working soon, it might be delayed indefinitely if it's more difficult than I'm anticipating. So we might wanna swap steps 1 and 2 in the above plan. I'm also open to more OS-friendly build/test options, but that'll be a longer discussion that we probably shouldn't block this docker image on.

skye avatar Apr 09 '21 01:04 skye

Hi @skye, I've looked into the CI thing. If there is interest, the Google team can configure the repository to build and publish images to Dockerhub automatically on release / push. It's free and doesn't handle testing, but it does automate the image building process. I think that's fine if we're not building from source (i.e., just pip-installing), but if we are building from source GitHub might complain (so this is best suited to a :version build upon release). The current Dockerfile assumes that the Google-built binaries at the storage API site are available before a release is tagged in the public GitHub repo (otherwise the image build will fail, since it tries to pull a non-existent wheel).

I will over the next couple of weeks:

  1. Create a PR. I'll add a separate docker/ folder to the root of the repository, which will comprise:

    • A README with what the images are, how to use them, and how to update the image / upload to Dockerhub (which is free)
    • version/ containing a Dockerfile like the one above. The user manually specifies the jaxlib version to use (or comments it out so that the version number can be passed in as an environment variable).
    • dev/ which uses pip-versions to install the latest released jaxlib. This should be swapped out eventually with a "build from current source" image (a 'nightly' or 'dev' branch). I'll also include the sketch of a Dockerfile for this swap (but it won't be working, I'll explain why).
  2. Build and distribute the resulting images under my personal account on Dockerhub, which will link back to the README in the jax repository.

    • For convenience, I'll have a jaxlib-gpu:with-jax tag that will take the latest version of CUDA jaxlib and add on the latest version of jax. (I won't do every possible combination of jax+jaxlib)
  3. Right now, builds will be manual (I don't have the bandwidth to set up a fork of the jax repo just to track releases and automate everything). I'll build an image for the last 4 releases of jaxlib, and build the jaxlib-gpu:with-jax and dev tags maybe once a week or so with a rudimentary cron job (roughly the publish sequence sketched after this list).

  4. I'll draft a GitHub-Actions based workflow that will build a new image and tag it accordingly upon release. Do you want this in a separate PR? (I'll modify the path so that it won't actually trigger, just so the team can see the config)
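
For concreteness, the manual publish I have in mind is roughly the following (the Dockerhub account and tags are placeholders, and it assumes the build-arg variant of the Dockerfile; a cron job or CI step could run the same sequence):

# build and publish a versioned image, then move the latest tag
docker build --build-arg JAXLIB_VERSION=0.1.64 -t andrewcsq/jaxlib-gpu:0.1.64 .
docker push andrewcsq/jaxlib-gpu:0.1.64
docker tag andrewcsq/jaxlib-gpu:0.1.64 andrewcsq/jaxlib-gpu:latest
docker push andrewcsq/jaxlib-gpu:latest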

What I'm hoping the jax team can provide is:

  1. Recommendations on how to test the images I build in a way that won't kill regular consumer hardware. Do I just follow the instructions here with a low n? (Btw I think there's a small typo there where ip should be pip)
  2. A Dockerhub ID so I can add the relevant account as a collaborator in the Dockerhub repository for push access. If Google takes over distribution officially, I'll deactivate the repository on Dockerhub.
  3. Teamwork to figure out the build-from-source image together (it will be an adventure, for sure).

I'll of course cooperate with any testing / setup that the JAX team needs. (The use of CUDA in Docker requires the nvidia container toolkit which only supports Debian, Ubuntu, RHEL, CentOS - but it's possible to get it working on Fedora. I don't know what Google uses internally) Happy to have a conversation either in issues or offline about this.

AndrewCSQ avatar Apr 09 '21 02:04 AndrewCSQ

Your plan sounds great!

I'll draft a GitHub-Actions based workflow that will build a new image and tag it accordingly upon release. Do you want this in a separate PR? (I'll modify the path so that it won't actually trigger, just so the team can see the config)

This is awesome. Separate PR please (I generally prefer smaller PRs whenever reasonable).

Recommendations on how to test the images I build in a way that won't kill regular consumer hardware. Do I just follow the instructions here with a low n?

Just to clarify, you're running this on a GPU machine? In my experience, running with -n auto on GPU, or any -n at all, results in OOMs, so unfortunately running the GPU tests is pretty slow. Something like: JAX_NUM_GENERATED_CASES=10 pytest tests should work, but will still take a while. You can experiment with -n 2 or some other low number if you want. Note that there are currently some failures at head on GPU too. The following is probably sufficient for a smoke test to make sure it's not completely broken, if you're willing to depend on post-facto CI testing for full validation: python tests/api_test.py
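
Wrapped in docker, that smoke test might look something like this (untested on my end; the image name is whatever you end up publishing, and it assumes git is available in the image):

# run api_test.py against a fresh jax checkout inside the image
docker run --rm --gpus all jaxlib-gpu:with-jax bash -c "git clone --depth 1 https://github.com/google/jax.git /tmp/jax && python3 /tmp/jax/tests/api_test.py"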

A Dockerhub ID so I can add the relevant account as a collaborator in the Dockerhub repository for push access.

I just created an account with ID skyewm. Do you think we should make an "official" JAX ID?

Teamwork to figure out the build-from-source image together (it will be an adventure, for sure).

We currently use one of TensorFlow's build images to build GPU jaxlibs (see the jaxlib build Dockerfile), which works well because most of jaxlib is a large TensorFlow dependency (since the XLA compiler is in TF's repo). I'd like to avoid maintaining our own build image that's compatible with TF's build when they already supply one, so would it be possible to run the TF docker image inside the dev image build? I recently refactored build_jaxlib_wheels.sh into helper functions, to make it easier to build a single CUDA wheel, so hopefully we can just call build_cuda_wheels from the dev Dockerfile.

skye avatar Apr 13 '21 18:04 skye

Hi @skye , thanks for getting back to me. RE: JAX docker ID, that can be for the future. For the moment, I just want a Google contact I can put as a collaborator in the Dockerhub repo for accountability purposes.

The point about tensorflow is worth discussing. I suspected as much (that JAX was using TensorFlow-GPU images, see my thread in the JAX discussions). Right now, my dockerfile intentionally uses Nvidia's CUDNN image as a base. This is useful to me (and any downstream projects) because we can then use JAX's images as a base for our own without TensorFlow branding (and bloat). If we use TF images for the dev image build, but not for the release tagged builds, then the dev image build environment might end up being different from the release images (but I'm okay to do this too).

If the JAX team wants images that are based on TensorFlow-GPU instead (whether for dev or for release, or both), I can modify the dockerfiles accordingly (really, it's just changing the line FROM nvidia/ to FROM tensorflow/). I'd be less inclined to maintain it (since it wouldn't be as useful for me downstream), but I can surely see the appeal of it for the JAX team.

I'll leave it to the JAX team which image I should use as the base. (Right now, the TF-GPU images are directly compatible with nvidia's CUDNN images - they are, after all, both based on Ubuntu 18.04+.) You are absolutely correct that, hopefully, this means that in the future JAX won't need to distribute wheels for every CUDA / CUDNN version. They can instead say "if you want to use JAX+CUDA, either use the docker image we provide, or build from source" <- this would probably be easier than the current system.

AndrewCSQ avatar Apr 13 '21 18:04 AndrewCSQ

I wasn't suggesting basing the dev image off the TF image. Instead, I was proposing we use the regular cuda build script (via the build_cuda_wheels bash function I linked above) in order to build jaxlib in the dev image. That in turn runs the TF docker image to build the wheel (hopefully we can consider this an implementation detail), and then copies the built wheel out of the TF image back into the calling environment. So we wouldn't be building in the stripped-down image, but we'd still be testing in that image after the build is finished. Does this make sense?

skye avatar Apr 13 '21 20:04 skye

Thanks for this, it was the only way I was able to get the GPU backend running locally. For anyone reading now, I had to modify the container as follows:

FROM nvidia/cuda:11.6.0-devel-ubuntu20.04

# declare the image name
ENV IMG_NAME=11.6.0-devel-ubuntu20.04 \
    # declare what jaxlib tag to use
    # if a CI/CD system is expected to pass in these arguments
    # the dockerfile should be modified accordingly
    JAXLIB_VERSION=0.3.0

# install python3-pip
RUN apt update && apt install python3-pip -y

# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

MattLangsenkamp avatar Feb 24 '22 15:02 MattLangsenkamp

This seems to be working but maybe I have made a mistake

FROM nvidia/cuda:11.7.0-devel-ubuntu20.04

# install python3-pip
RUN apt update && apt install python3-pip -y

RUN pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

MatthewCaseres avatar Jul 30 '22 20:07 MatthewCaseres

+1 I have just spent hours trying to get jax pmap to work with 3090s. It is way too hard to find docker containers & jax versions that work together; actually, I couldn't find a combination where everything works. Would love it if the jax project peeps could provide some dockerhub images.

pwais avatar Aug 06 '22 21:08 pwais

@sudhakarsingh27 @nvcastet @nouiz for viz

mjsML avatar Aug 06 '22 21:08 mjsML

@mjsML in particular, on a dual 3090 system:

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

Base image: nvidia/cuda:11.2.1-cudnn8-devel-ubuntu20.04

pip3 install jax[cuda11_cudnn82]==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

XLA_PYTHON_CLIENT_PREALLOCATE=false but nothing else special

idk, single-GPU works, but honestly the pytorch Dockerhub images today work way better, multi-GPU included.

pwais avatar Aug 06 '22 21:08 pwais

+1 We really need a docker image running jax, whether for deployment or development. It really helps ;)

ofey404 avatar Sep 30 '22 08:09 ofey404

Maybe this is the sustainable answer for JAX on multi-node, multi-GPU: JAX Container Early Access

lukaemon avatar Oct 05 '22 02:10 lukaemon

NVIDIA official images are coming soon 😄 @pwais @lukaemon

mjsML avatar Oct 05 '22 12:10 mjsML

@pwais One possible cause for your multi-GPU problem is that you need to provide SHM space to your docker container for the use of NCCL. Try setting shm-size to something like 2GB.
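
i.e. something along these lines (image and script are placeholders, and 2g is just a starting point):

docker run --rm --gpus all --shm-size=2g <your-jax-image> python3 your_multi_gpu_script.py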

hawkinsp avatar Oct 05 '22 12:10 hawkinsp

The NVIDIA JAX early access (EA) container is now released, register to get it :)

mjsML avatar Nov 02 '22 20:11 mjsML

This seems to be working but maybe I have made a mistake

FROM nvidia/cuda:11.7.0-devel-ubuntu20.04

# install python3-pip
RUN apt update && apt install python3-pip -y

RUN pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

nvidia/cuda:11.7.0-devel-ubuntu20.04 ships with cudnn 8.5 but "jax[cuda]" only works with cudnn 8.6+. The following works for me:

FROM nvidia/cuda:11.7.0-devel-ubuntu20.04

# install python3-pip
RUN apt update && apt install python3-pip -y

RUN pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Otherwise, use nvidia/cuda:11.8.0-devel-ubuntu20.04 which ships with cudnn 8.7

FROM nvidia/cuda:11.8.0-devel-ubuntu20.04

# install python3-pip
RUN apt update && apt install python3-pip -y

RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

weiyaw avatar Mar 18 '23 14:03 weiyaw

What current version of FROM nvidia/cuda: could we use with jax today?

bhack avatar Sep 04 '23 21:09 bhack

I guess nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 should work if you update the JAX/JAXlib version installed.
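
An untested sketch of what that could look like (the cuda12_local extra is an assumption on my part, matching the CUDA/cuDNN shipped in that base image; cuda12_pip is the alternative if you want pip-provided CUDA libraries):

FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04

RUN apt update && apt install python3-pip -y

# assumes the locally installed CUDA 12 / cuDNN 8 in this base image is what jaxlib will link against
RUN pip3 install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html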

nouiz avatar Sep 05 '23 13:09 nouiz

None of the Dockerfile configurations worked for me, and I ended up using the NVIDIA JAX containers:

docker pull ghcr.io/nvidia/jax:nightly-2023-09-09

pourmand1376 avatar Sep 10 '23 06:09 pourmand1376

JAX Toolbox includes the nightly container referenced above along with open Dockerfiles.

sbhavani avatar Sep 25 '23 20:09 sbhavani