dalle-playground

DALL-E server doesn't work with AMD GPU

SyntaxOutlaw opened this issue 3 years ago • 7 comments

Running natively on Ubuntu 20.04 LTS with an AMD RX 6600 XT GPU. Local development, no Docker.

The GPU works with OpenGL, since I can run glmark2 and Blender using OpenGL on it, but I can't get it to work with the DALL-E server.

$ python3 app.py 8080
--> Starting DALL-E Server. This might take up to two minutes.
hipErrorNoBinaryForGpu: Unable to find code object for all current devices!
Aborted (core dumped)

Any help would be great!

SyntaxOutlaw, Jun 12 '22

This requires CUDA, which is only available on NVIDIA GPUs.

Codel1417, Jun 12 '22

Any plans to support AMD GPUs?

SyntaxOutlaw, Jun 12 '22

> This requires cuda which is only available on NVidia GPUs

I believe it requires PyTorch, which has experimental support for AMD GPUs via ROCm (https://pytorch.org/get-started/locally/). When installing PyTorch through pip3, you can verify that PyTorch can use the GPU with:

$ python3
>>> import torch
>>> torch.cuda.is_available()
True
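For a slightly more informative check than the one above, here is a minimal sketch. It assumes a ROCm build of PyTorch, where torch.version.hip holds a version string (it is None on CUDA and CPU-only builds) and the AMD GPU is still reached through the torch.cuda API. The helper name describe_torch_gpu is hypothetical, not part of PyTorch or this project.

```python
"""Report whether PyTorch can see a GPU, and whether it is a ROCm (HIP) build."""


def describe_torch_gpu():
    """Return a short, human-readable summary of PyTorch's GPU support.

    Hypothetical helper for illustration; it only relies on torch.cuda.is_available(),
    torch.cuda.get_device_name() and torch.version.hip.
    """
    try:
        import torch  # the package is imported as "torch", not "pytorch"
    except ImportError:
        return "PyTorch is not installed"

    # On ROCm builds torch.version.hip is a version string; on CUDA/CPU builds it is None.
    backend = "ROCm/HIP" if getattr(torch.version, "hip", None) else "CUDA or CPU-only"
    if not torch.cuda.is_available():
        return f"PyTorch {torch.__version__} ({backend} build): no GPU visible"
    # ROCm devices are exposed through the torch.cuda namespace as well.
    name = torch.cuda.get_device_name(0)
    return f"PyTorch {torch.__version__} ({backend} build): using {name}"


if __name__ == "__main__":
    print(describe_torch_gpu())
```

Note that a positive result here only shows that PyTorch sees the card; it says nothing about the framework the DALL-E server itself uses.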

With my 6900 XT, I was able to get PyTorch to recognize my GPU. However, I can't get this app to recognize it; I also get an error:

$ python3 app.py 20000
--> Starting DALL-E Server. This might take up to two minutes.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Any info would be appreciated. I'm on Ubuntu 20.04.4 LTS, running the 5.13.0-48-generic kernel.

General hardware info: Ryzen 7 5800X, 32 GB RAM, RX 6900 XT. I also have a GTX 1060 installed, because I thought ROCm might only be available for (or intended for) headless setups, so I could use the NVIDIA GPU for display and the AMD GPU just for compute. The GTX 1060 is currently unused, because PyTorch still seems to recognize the 6900 XT.

Update: it appears to be an issue with JAX and/or TensorFlow. By downloading the latest version of JAX, building it with ROCm support according to the instructions here (https://github.com/google/jax/issues/2012#issuecomment-738896364), and installing it, I was able to get this app to stop displaying the "No GPU/TPU found" warning, but the backend now crashes on startup. I believe it may be caused by conflicting ROCm versions (PyTorch uses an older one). I'm not sure; I'm quite new to all of this, to be honest.
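Since the "No GPU/TPU found, falling back to CPU" warning comes from JAX rather than PyTorch, a quick way to see which devices JAX actually picked up is to list the platform of each entry returned by jax.devices(). A minimal sketch, assuming only that JAX may or may not be installed; the helper names are hypothetical, and because the exact platform string reported for a ROCm GPU may differ between JAX versions, it only tests for "anything other than cpu".

```python
"""Check whether JAX found an accelerator or fell back to the CPU."""


def has_accelerator(platforms):
    """Return True if any device platform other than 'cpu' is present.

    platforms: iterable of strings, e.g. the .platform of each device from jax.devices().
    """
    return any(p.lower() != "cpu" for p in platforms)


def jax_platforms():
    """List the platform of every device JAX can see (empty list if JAX is missing)."""
    try:
        import jax
    except ImportError:
        return []
    return [d.platform for d in jax.devices()]


if __name__ == "__main__":
    platforms = jax_platforms()
    if not platforms:
        print("JAX is not installed")
    elif has_accelerator(platforms):
        print("JAX sees an accelerator:", platforms)
    else:
        print("JAX is running on CPU only; check that a ROCm-enabled jaxlib is installed")
```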

osimmac, Jun 12 '22

I have it working on a 6800 XT without issues.

ROCm works fine on a full desktop; there is no need for a headless setup. ROCm 5.1.0 works correctly. You need to build jaxlib from source with the --enable_rocm flag, and also specify the path to the correct ROCm install with --rocm_path=/opt/rocm-5.1.0.

Repo for jaxlib: https://github.com/google/jax

ROCm installation guide: https://docs.amd.com/bundle/ROCm-Installation-Guide-v5.1/page/How_to_Install_ROCm.html#_Installation_Methods

After that, the regular install process for DALL-E works fine; just make sure it doesn't install a non-ROCm jaxlib over your build!

I should note that I had to check out the latest tagged commit of the jax repository (jax-v0.3.13 at the time of writing); just cloning the repo doesn't work correctly.
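Picking "the latest tagged commit" by eye is easy to get wrong, because a plain string sort puts jax-v0.3.9 after jax-v0.3.13. A minimal sketch that compares version numbers numerically, assuming tags follow the jax-vX.Y.Z pattern mentioned here (this thread also mentions a jaxlib-v0.3.14 tag, so the prefix is a parameter). The function name latest_tag is hypothetical.

```python
"""Pick the newest release tag with a given prefix from `git tag` output."""
import re


def latest_tag(tags, prefix="jax-v"):
    """Return the highest-versioned tag starting with prefix, or None.

    Versions are compared numerically (0.3.13 > 0.3.9), unlike a plain string sort.
    Tags whose suffix is not purely dotted integers (e.g. release candidates) are skipped.
    """
    pattern = re.compile(re.escape(prefix) + r"(\d+(?:\.\d+)*)$")
    best, best_key = None, None
    for tag in tags:
        match = pattern.match(tag.strip())
        if not match:
            continue
        key = tuple(int(part) for part in match.group(1).split("."))
        if best_key is None or key > best_key:
            best, best_key = tag.strip(), key
    return best


if __name__ == "__main__":
    import subprocess

    # Run inside a clone of https://github.com/google/jax, then `git checkout` the result.
    out = subprocess.run(["git", "tag"], capture_output=True, text=True, check=True).stdout
    print(latest_tag(out.splitlines(), prefix="jax-v"))
```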

System specs for reference:

Ryzen 7 5800X3D
32 GB RAM
RX 6800 XT
Ubuntu 20.04 LTS, kernel 5.13.0-51-generic

NevesLucas, Jun 18 '22

FWIW, I don't think this project uses PyTorch at all; I'm not sure why it was said that it does. I compiled jaxlib for ROCm, but I get the following error:

2022-06-29 04:45:34.924212: I external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:145] StreamExecutor ROCM device (0) is of unsupported AMDGPU version : gfx1010. The supported AMDGPU versions are gfx1030, gfx900, gfx906, gfx908, gfx90a.
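The key detail in this log line is the mismatch between the card's architecture (gfx1010, the RX 5700 series) and the targets this jaxlib build was compiled for. A minimal sketch that pulls both out of a line worded exactly like the one above; the regular expression is an assumption based on this single message, other XLA/TensorFlow versions may word it differently, and the helper name parse_unsupported_gpu is hypothetical.

```python
"""Extract the GPU architecture and the supported list from XLA's 'unsupported AMDGPU version' line."""
import re

# Pattern derived from one observed log message; not a stable XLA format.
_UNSUPPORTED_RE = re.compile(
    r"unsupported AMDGPU version\s*:\s*(gfx[0-9a-z]+)\."
    r".*?supported AMDGPU versions are\s+(gfx[0-9a-z]+(?:,\s*gfx[0-9a-z]+)*)"
)


def parse_unsupported_gpu(log_line):
    """Return (device_arch, supported_archs), or None if the line doesn't match."""
    match = _UNSUPPORTED_RE.search(log_line)
    if match is None:
        return None
    supported = [arch.strip() for arch in match.group(2).split(",")]
    return match.group(1), supported


if __name__ == "__main__":
    line = ("StreamExecutor ROCM device (0) is of unsupported AMDGPU version : gfx1010. "
            "The supported AMDGPU versions are gfx1030, gfx900, gfx906, gfx908, gfx90a.")
    arch, supported = parse_unsupported_gpu(line)
    print(arch, "is supported" if arch in supported else "is not in: " + ", ".join(supported))
```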

I edited TensorFlow/jaxlib to allow my GPU; the changes are here: https://gist.github.com/DarkShadow44/60decf1ae76cd1143479620193b53ebe

Now I get:

ImportError: /usr/lib/python3.10/site-packages/jaxlib/xla_extension.so: undefined symbol: _ZN10tensorflow33tensor_float_32_execution_enabledEv

I'm not sure whether that's because of my unsupported GPU or because I compiled it incorrectly.

DarkShadow44, Jun 29 '22

@DarkShadow44 Do you have an RX 5700 XT? You might need to fully clean the Bazel cache if you previously tried building with the wrong device flags. Where in the original source tree are the files you edited?

The commands I used to build JAX for the RX 6800 XT were:

python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-5.1.0

pip3 install -e .

I think JAX's default list of supported devices omits gfx1010, so you can set it explicitly with:

python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-5.1.0 --rocm_amdgpu_target "gfx1010"

pip3 install -e .

Make sure to run git checkout jaxlib-v0.3.14 to get the latest tagged release.
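To avoid typos when switching between the two build commands above, a minimal sketch that assembles the same argument list. It assumes build/build.py is run from the root of a jax checkout and that several targets would be passed as one comma-separated value; the helper name jaxlib_build_command is hypothetical. A later reply in this thread reports that adding the target flag alone was not enough for a regular RX 5700.

```python
"""Assemble the jaxlib ROCm build command used in this thread (hypothetical helper)."""
import shlex


def jaxlib_build_command(rocm_path="/opt/rocm-5.1.0", amdgpu_targets=None):
    """Return the argv list for JAX's build/build.py with ROCm enabled.

    amdgpu_targets: optional list such as ["gfx1010"]; when omitted, the build's
    default target list is used (which, per this thread, omits gfx1010).
    """
    cmd = ["python3", "build/build.py", "--enable_rocm", f"--rocm_path={rocm_path}"]
    if amdgpu_targets:
        # Flag spelled as in the thread; joining several targets with commas is an assumption.
        cmd += ["--rocm_amdgpu_target", ",".join(amdgpu_targets)]
    return cmd


if __name__ == "__main__":
    # Print the command for an RX 5700 (gfx1010); run it from the root of a jax checkout,
    # then, as in the thread, run: pip3 install -e .
    print(shlex.join(jaxlib_build_command(amdgpu_targets=["gfx1010"])))
```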

NevesLucas, Jun 30 '22

I have a regular RX 5700, not the XT. I edited "tensorflow/stream_executor/device_description.h" to add gfx1010; just adding rocm_amdgpu_target doesn't work.

DarkShadow44, Jun 30 '22

@NevesLucas I would like to try this on a 6900 XT. Do you have a Dockerfile for AMD?

exander77, Sep 29 '22

@exander77 Hi, sorry, I never put it in a Docker environment; I did a native install on my Linux system.

NevesLucas, Nov 20 '22

Closing, as @NevesLucas seems to have provided a working solution and the thread has gone stale.

Thanks to everyone for their input!

SyntaxOutlaw, Jul 09 '23