dalle-playground
DALL-E server doesn't work with AMD GPU
Running on Ubuntu 20.04 LTS natively, with an AMD RX 6600 XT GPU. Local development, no Docker.
The GPU works with OpenGL, as I can run glmark2 and Blender on it, but I can't seem to get the DALL-E server running.
$ python3 app.py 8080
--> Starting DALL-E Server. This might take up to two minutes.
hipErrorNoBinaryForGpu: Unable to find code object for all current devices!
Aborted (core dumped)
Any help would be great!
This requires CUDA, which is only available on NVIDIA GPUs.
Any plans to support AMD GPUs?
I believe it requires PyTorch, which has experimental support for AMD GPUs via ROCm: https://pytorch.org/get-started/locally/. After installing PyTorch through pip3, you can verify that PyTorch can use the GPU with:
$ python3
>>> import torch
>>> torch.cuda.is_available()
True
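You can also confirm that ROCm itself sees the card (a quick check, assuming a default install under /opt/rocm; the path may differ on your system):
$ /opt/rocm/bin/rocminfo | grep -i gfx
This should print your GPU's gfx target (e.g. gfx1030), which matters later if the default jaxlib build doesn't include it.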
With my 6900 XT, I was able to get PyTorch to recognize my GPU. However, I can't get this app to recognize it; I run into a similar error:
$ python3 app.py 20000
--> Starting DALL-E Server. This might take up to two minutes.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Any info would be appreciated. I'm on Ubuntu 20.04.4 LTS, kernel 5.13.0-48-generic.
General hardware info: Ryzen 7 5800X, 32 GB RAM, 6900 XT. I also have a GTX 1060, since ROCm may only be available/intended for headless setups, so I thought I could use the NVIDIA GPU for video and the AMD GPU just for compute. The 1060 is currently unused, because PyTorch still seems to recognize the 6900 XT.
Update: it appears to be an issue with JAX and/or TensorFlow. By downloading the latest version of JAX and building it according to the instructions here (https://github.com/google/jax/issues/2012#issuecomment-738896364) to enable ROCm support, I was able to get this app to stop showing the "No GPU/TPU found" error, but the backend now crashes on startup. I believe it may be caused by conflicting ROCm versions (PyTorch uses an older one). I'm not sure; I'm quite new to all of this.
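If conflicting ROCm versions are the culprit, a quick way to see which versions are actually installed side by side (ROCm installs into versioned directories under /opt) is:
$ ls -d /opt/rocm*
If more than one version shows up, make sure jaxlib and PyTorch were built against the same one.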
I have it working on a 6800 XT without issues.
ROCm works fine on a full desktop; there is no need for a headless setup. ROCm 5.1.0 works correctly. You need to build jaxlib from source with the --enable_rocm flag, and also specify the path to the correct ROCm install with --rocm_path=/opt/rocm-5.1.0.
repo for jaxlib: https://github.com/google/jax
rocm installer guide: https://docs.amd.com/bundle/ROCm-Installation-Guide-v5.1/page/How_to_Install_ROCm.html#_Installation_Methods
After that, the regular install process for dalle-playground works fine; just ensure it doesn't attempt to install a non-ROCm jaxlib over your build!
I should note that I had to make sure the jax repository was checked out at the latest tagged commit (jax-v0.3.13 at time of writing); just cloning the repo doesn't work correctly. A rough sketch of the whole sequence follows.
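For reference, roughly what the build looks like (assuming ROCm 5.1.0 lives under /opt/rocm-5.1.0 and that build.py drops the jaxlib wheel into dist/, which is the upstream default):
$ git clone https://github.com/google/jax
$ cd jax
$ git checkout jax-v0.3.13                                           # latest tagged release at the time
$ python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-5.1.0   # builds the jaxlib wheel
$ pip3 install dist/*.whl                                            # install the ROCm-enabled jaxlib
$ pip3 install -e .                                                  # install jax itself from the checkout
If installing dalle-playground's requirements later pulls in a stock jaxlib, reinstall your wheel from dist/ on top of it.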
System specs for reference:
Ryzen 7 5800X3D
32 GB RAM
6800 XT
Ubuntu 20.04 LTS, kernel 5.13.0-51-generic
FWIW, I don't think this project uses PyTorch at all; I'm not sure why the comments above say it does.
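Since the backend errors here come from JAX/XLA, a quick way to check what JAX actually sees (rather than what PyTorch sees) is, for example:
$ python3 -c "import jax; print(jax.devices())"
If it only lists a CPU device, the installed jaxlib has no working GPU support.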
I compiled jaxlib for rocm, but I get the following error:
2022-06-29 04:45:34.924212: I external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:145] StreamExecutor ROCM device (0) is of unsupported AMDGPU version : gfx1010. The supported AMDGPU versions are gfx1030, gfx900, gfx906, gfx908, gfx90a.
I edited tensorflow/jaxlib to allow my GPU here: https://gist.github.com/DarkShadow44/60decf1ae76cd1143479620193b53ebe
Now I get ImportError: /usr/lib/python3.10/site-packages/jaxlib/xla_extension.so: undefined symbol: _ZN10tensorflow33tensor_float_32_execution_enabledEv
Not sure if that's because of my unsupported GPU or because I compiled it wrong somehow?
@DarkShadow44 do you have an RX 5700 XT? You might need to do a full clean of the Bazel cache if you previously tried building with the wrong device flags. Where in the original source tree are the files you edited?
The arguments I used to build jax for the RX 6800 XT were:
python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-5.1.0
pip3 install -e .
I think jax's default device support list omits gfx1010, so you can explicitly set it with:
python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-5.1.0 --rocm_amdgpu_target "gfx1010"
pip3 install -e .
Make sure to git checkout jaxlib-v0.3.14 to get the latest tagged release.
I have a normal RX 5700, no XT. I edited "tensorflow/stream_executor/device_description.h" to add gfx1010; just adding rocm_amdgpu_target doesn't work.
@NevesLucas I would like to try on 6900xt, do you have a Dockerfile for AMD?
@exander77 hi sorry I never put it in a docker environment, I did a native install to my Linux system
Closing, as @NevesLucas seems to have provided a working solution and the thread has gone stale.
Thanks to everyone for their input!