dalle-mini
Solves #260 by changing the Dockerfile to use a specified version of jaxlib
Pins jaxlib to an exact version, which fixes builds that fail with the following error:
#5 12.13 ERROR: No matching distribution found for jaxlib==0.3.10+cuda11.cudnn82 (from jax[cuda])
------
executor failed running [/bin/sh -c pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html && pip install -q git+https://github.com/borisdayma/dalle-mini.git git+https://github.com/patil-suraj/vqgan-jax.git]: exit code: 1
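One way to express such a pin on the command line is shown below. This is a minimal sketch, not the PR's exact change. It assumes you pin to the version named in the error above and pull the wheel from the CUDA wheel index that is mentioned later in this thread. The helper name `pinned_jaxlib_cmd` is hypothetical. Check the index to confirm that this wheel exists for your Python version.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print a pip command that pins jaxlib to an exact
# CUDA build instead of letting "jax[cuda]" resolve one on its own.
# Assumption: the requested wheel is listed on jax_cuda_releases.html.
pinned_jaxlib_cmd() {
  local version="$1"   # e.g. 0.3.10+cuda11.cudnn82, taken from the error above
  local index="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
  printf 'pip install "jaxlib==%s" -f %s\n' "$version" "$index"
}

pinned_jaxlib_cmd "0.3.10+cuda11.cudnn82"
```

The printed command could then be used in the Dockerfile's `RUN` step in place of the unpinned `jax[cuda]` install.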
I think you could just use the dalle-mini package from PyPI instead of installing from git. I'm not knowledgeable about Docker, so I'm just waiting to see if there are more comments.
I will try this once I confirm that it fully works for the notebook.
I was following the tutorial posted on YouTube about WSL installs, ran into the issue, and so tried to solve it.
I did get it running properly by specifying the exact jaxlib file, but I'm not sure whether that's a good or bad idea for the project overall. The only issue was that my original commit used a non-CUDA version of jaxlib, so I am retrying with the cuda11 version (the version that cannot be found using the jax_releases.html file).
For clarification, I was following this video, and the solution in this PR got me past being stuck on the JAX install on Windows:
https://www.youtube.com/watch?v=OqEuEe-xSKk
(BTW Thanks @cdgamedev)
When I change the Dockerfile to do
RUN pip install dalle-mini
I run into this: JAX runs on my CPU (the nocuda version) instead of the GPU. I get the same result on my own machine when running ./build_docker.sh --auto, ./build_docker.sh --cpu, or plain ./build_docker.sh.
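You can check directly which backend JAX selected inside the container. The sketch below is minimal. It assumes that `python3` (or whatever the hypothetical `PYTHON` variable points to) is the interpreter that has JAX installed. It relies on `jax.default_backend()`, which reports `cpu`, `gpu` or `tpu`.

```shell
#!/usr/bin/env bash
# Report which backend JAX selected: "gpu" means the CUDA build is in use,
# "cpu" means the nocuda build (or no visible GPU). Prints "missing" if
# the interpreter or JAX cannot be run at all.
# Assumption: PYTHON is a hypothetical override; defaults to python3.
jax_backend() {
  local py="${PYTHON:-python3}"
  "$py" -c 'import jax; print(jax.default_backend())' 2>/dev/null || echo missing
}

backend="$(jax_backend)"
echo "JAX backend: $backend"
if [ "$backend" != "gpu" ]; then
  echo "warning: JAX is not using a GPU" >&2
fi
```

stderr is discarded, which also hides JAX's own "no GPU" warning. The printed backend name is the only signal this check uses.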
When I run the bash script as ./build_docker.sh --gpu, it then gives me:
The Dockerfile searches the jax_releases.html index for a CUDA build of jaxlib. This works on some machines (native Linux) but apparently not on WSL. It has been a common issue in Blake's Discord server, and this PR resolves it.
While I don't think my current solution is the most elegant (and I will work on further improvements when I get the chance), letting the user choose between the three flags --auto|--cpu|--gpu works as a solution to the issue many users currently face on WSL.
The main problem I can foresee is that this breaks if the Python version changes (e.g. from 3.8 to a later release), because these URLs specify cp38, the CPython 3.8 wheel tag.
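To avoid hardcoding cp38, a script could derive the wheel tag from the Python version. The sketch below is minimal. It assumes a `major.minor[.patch]` version string and follows the `cp` + major + minor pattern seen in the URLs. The helper name `cpython_tag` is hypothetical.

```shell
#!/usr/bin/env bash
# Derive the CPython wheel tag (e.g. cp38) from a "major.minor[.patch]"
# version string, so a wheel URL need not hardcode cp38.
cpython_tag() {
  local v="$1"
  local major="${v%%.*}"     # text before the first dot
  local rest="${v#*.}"       # text after the first dot
  local minor="${rest%%.*}"  # drop any patch component
  echo "cp${major}${minor}"
}

cpython_tag 3.8      # prints cp38
cpython_tag 3.10.4   # prints cp310
```

Inside the image, the version string could come from the running interpreter, for example `python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])'`.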
I've never used Docker before and it's a learning process that I'm enjoying. @batrlatom and @mallorbc have previously edited the Dockerfile, so I'd love to know their take on my changes (:
Edit:
Specifying the link will likely (untested) cause issues on macOS. I'm updating this now. I'm also renaming the flags to --auto|--cuda|--nocuda and using jax_cuda_releases.html instead of jax_releases.html for the --cuda variant.
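The sketch below shows one way a host-side script might map the --auto|--cuda|--nocuda flags to a JAX install command. It is not the PR's build_docker.sh. The function name `jax_install_cmd` is hypothetical. --auto assumes that a working `nvidia-smi` means a usable GPU, and --nocuda uses the CPU-only `jax[cpu]` extra.

```shell
#!/usr/bin/env bash
# Hypothetical sketch: map --auto|--cuda|--nocuda to a JAX install command.
# Assumption: --auto treats a working nvidia-smi as "GPU present".
jax_install_cmd() {
  local mode="$1"
  if [ "$mode" = "--auto" ]; then
    if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; then
      mode="--cuda"
    else
      mode="--nocuda"
    fi
  fi
  case "$mode" in
    --cuda)
      # CUDA wheels come from the jax_cuda_releases.html index.
      echo 'pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' ;;
    --nocuda)
      echo 'pip install --upgrade "jax[cpu]"' ;;
    *)
      echo "usage: $0 --auto|--cuda|--nocuda" >&2
      return 1 ;;
  esac
}

# Default to --auto when no flag is given.
jax_install_cmd "${1:---auto}"
```

The detection happens on the host rather than in a Dockerfile `RUN` step. As far as I know, `docker build` normally has no GPU access, so a check run during the build would always see "no GPU".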
I have updated this and tested it as best I can with the hardware I have.
The only platforms I can't test are macOS and native Linux, but my changes work on WSL for me and others.
I'm not sure whether flags are a good idea overall or fit the project, but that's your decision as it's your project.
If you feel this needs more work overall, let me know and I'd be happy to make changes and improvements wherever you think they are needed (:
My understanding is that you need JAX built with CUDA for GPU support. A plain pip install seems to give a CPU-only build, so generating any image takes MUCH longer. Whatever the final solution is, I think the Docker container should support GPUs out of the box.
Change https://storage.googleapis.com/jax-releases/jax_releases.html to https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. The new URL should now be used instead of the old one.
It's worth noting that some forms of virtualization cause problems with JAX because they don't support certain CPU instructions. I had some issues with that.
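The thread doesn't say which instructions are involved. One likely candidate is AVX, because prebuilt jaxlib wheels are compiled to use it. The check below is a minimal Linux-only sketch that looks for the `avx` flag in `/proc/cpuinfo`. It assumes AVX is the relevant instruction set. The helper name `has_avx` is hypothetical.

```shell
#!/usr/bin/env bash
# Linux-only sketch: report whether the (virtual) CPU advertises AVX.
# Assumption: missing AVX is the instruction-set problem in question.
has_avx() {
  if [ ! -r /proc/cpuinfo ]; then
    echo unknown
  elif grep -qw avx /proc/cpuinfo; then   # -w: match "avx" as a whole word, not just "avx2"
    echo yes
  else
    echo no
  fi
}

echo "AVX available: $(has_avx)"
```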
Is #272 related to this PR?
Switching to jax_cuda_releases.html worked for me, thank you! Juggling pip, conda, and conda-forge + CUDA prefixes was driving me nuts. I was getting a "no GPU, switching to CPU" message, and this fixed it.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Thanks... I rebuilt the whole environment and started over just to see whether there was a compatibility issue here. I still get "Killed" after the DALL·E Mega model downloads, and it was working on my 3090 a couple of weeks ago. It's probably running out of memory, but 24GB should be fine, and only 300MB is allocated before I run it. :/
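On Linux, a bare "Killed" usually means the kernel's OOM killer ended the process because host RAM ran out. That limit is separate from the 24GB of GPU memory on the 3090. The check below is a minimal sketch of how to see available host memory before loading the model. It is Linux-only and reads `MemAvailable` from `/proc/meminfo`. The helper name `available_ram_mib` is hypothetical. The thread doesn't say how much RAM loading the model needs.

```shell
#!/usr/bin/env bash
# Print the kernel's MemAvailable estimate in MiB (Linux only), or
# "unknown" where /proc/meminfo is missing or lacks that field.
available_ram_mib() {
  if [ -r /proc/meminfo ]; then
    # MemAvailable is reported in kB; divide by 1024 for MiB.
    awk '/^MemAvailable:/ { printf "%d\n", $2 / 1024; found = 1 }
         END { if (!found) print "unknown" }' /proc/meminfo
  else
    echo unknown
  fi
}

echo "Available host RAM: $(available_ram_mib) MiB"
```

After a kill, the kernel log (for example `dmesg`, which may need root) should show whether the OOM killer was responsible.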