rocm 5.7.1 + 7900 xtx + jax:latest docker image not working
Hardware: 7900 XTX, Ubuntu 22.04 LTS, ROCm 5.7.1 (the first version with official 7900 XTX support)
I am able to run the PyTorch image, e.g.:
drun rocm/pytorch
root@minerva:~# python3
Python 3.9.18 (main, Sep 11 2023, 13:41:44)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.device(0)
device(type='cuda', index=0)
>>> torch.cuda.get_device_properties(0).total_memory
25753026560
and have trained simple CNN models using this setup.
However, when I run the JAX version of the same image, I get the following error as soon as I import jax:
drun rocm/jax:rocm5.7.0-jax0.4.20-py3.11.0
root@minerva:/root# python3
Python 3.11.0 (main, Nov 16 2023, 20:45:15) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-11-29 19:56:58.717575: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
>>> jax.devices()
[CpuDevice(id=0)]
Is there something I should be doing differently? Is this configuration officially supported?
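For anyone hitting the same thing: a quick way to tell whether JAX actually picked up the GPU, rather than silently falling back to the CPU as in the [CpuDevice(id=0)] output above, is to filter the device list by platform. This is a minimal sketch that assumes only the public jax.devices() call and the platform attribute on each device (CPU devices report "cpu"); the helper name accelerator_devices is hypothetical, not part of JAX.

```python
"""Report whether JAX sees an accelerator or has fallen back to the CPU."""


def accelerator_devices(devices):
    """Return the devices whose platform is not "cpu".

    `devices` is any iterable of objects with a `.platform` attribute,
    such as the list returned by jax.devices(). (Hypothetical helper.)
    """
    return [d for d in devices if d.platform != "cpu"]


def main():
    try:
        import jax
    except ImportError:
        print("jax is not installed in this environment")
        return
    gpus = accelerator_devices(jax.devices())
    if gpus:
        print("JAX is using:", gpus)
    else:
        # Same symptom as the [CpuDevice(id=0)] output in the report.
        print("No accelerator visible; JAX fell back to the CPU")


if __name__ == "__main__":
    main()
```

jax.devices() with no argument returns the devices of the default backend, so an empty result here means the GPU backend never initialized.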
Same for the latest image:
drun rocm/jax:latest
[sudo] password for skoonce:
root@minerva:/root# python3
Python 3.10.0 (default, Nov 16 2023, 22:24:12) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-11-30 12:48:26.684744: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
Per the discussion in https://github.com/google/jax/issues/7598#issuecomment-1834072178, waiting for ROCm 6.0!
Updated to ROCm 6.0. PyTorch image working!
Still getting this with rocm/jax:latest
drun rocm/jax
root@minerva:/root# python3
Python 3.10.0 (default, Dec 12 2023, 19:54:24) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-12-16 22:23:14.261510: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
Used rocm/dev-ubuntu-22.04:latest with ROCm 6.0.0 and compiled jax from source at branch v0.4.23 (with the rocm/xla repo on the same branch). No dice: same error, and running the GPU tests via Bazel crashed the system. It would be nice to get jax working...
Edit: gfx1100 (the 7900 XTX) is simply not supported yet.
In another HPC (France's Adastra, MI250) I managed to get jax running under Singularity.
I get the same error as you:
2024-01-09 00:35:03.802451: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
but ROCm does work:
>>> jax.devices()
[rocm(id=0), rocm(id=1), rocm(id=2), rocm(id=3), rocm(id=4), rocm(id=5), rocm(id=6), rocm(id=7)]
>>> vs.samples.sharding
SingleDeviceSharding(device=rocm(id=0))
>>>
So now I wonder what this error means...
The target list is default="gfx900,gfx906,gfx908,gfx90a,gfx1030", so I'm thinking newer GPUs aren't supported yet and fall back to the CPU? So yes, it will run, but I don't think it's using your MI250. The error means it couldn't load the DNN plugin, meaning no GPU support.
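The check described here can be made concrete by comparing each GPU's gfx architecture against that target list. A minimal sketch, assuming (1) the list quoted above is what your jaxlib/XLA build was compiled for, which varies by build and version, and (2) rocminfo prints a "Name: gfxNNNN" line for each GPU agent. The helpers gpu_archs and missing_targets are hypothetical names. The architectures used for checking come from this thread: gfx1100 for the 7900 XTX, gfx90a for the MI250, gfx908 for the MI100.

```python
"""Check which GPU architectures are missing from a ROCm build's target list."""
import re
import subprocess

# Target list quoted in this thread; assumed to match your jaxlib build.
BUILD_TARGETS = "gfx900,gfx906,gfx908,gfx90a,gfx1030"


def gpu_archs(rocminfo_text):
    """Extract unique gfx architecture names from rocminfo output.

    Only lines of the form "Name: gfx..." match, so "Marketing Name:" lines,
    CPU agent names and ISA strings like "amdgcn-amd-amdhsa--gfx1100" are skipped.
    """
    archs = re.findall(r"^\s*Name:\s+(gfx[0-9a-f]+)\s*$", rocminfo_text, re.MULTILINE)
    return sorted(set(archs))


def missing_targets(archs, targets=BUILD_TARGETS):
    """Return the architectures that the comma-separated target list lacks."""
    supported = {t.strip() for t in targets.split(",")}
    return [a for a in archs if a not in supported]


if __name__ == "__main__":
    try:
        out = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, check=True
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        print("rocminfo not available; pass its output to gpu_archs() manually")
    else:
        archs = gpu_archs(out)
        print("GPUs:", archs, "missing from build targets:", missing_targets(archs))
```

With that list, gfx1100 comes back as missing while gfx90a and gfx908 do not, which lines up with the 7900 XTX failing and the MI250/MI100 reports below.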
I think it is running calculations on the GPU (judging also by the runtime of the code, which matches that of similar NVIDIA GPUs and is surely not the CPU).
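Judging by runtime is a reasonable sanity check, and it can be done explicitly. A rough sketch, assuming a JAX install where jax.devices("cpu") is available alongside the default backend; the matrix size, repeat count, and the helpers median_seconds and time_matmul are arbitrary, hypothetical choices. Two details matter: JAX dispatches work asynchronously, so each timing waits on block_until_ready(), and the first call is excluded because it includes compilation.

```python
"""Rough CPU-vs-default-device timing of a jitted matmul in JAX."""
import statistics
import time


def median_seconds(fn, repeats=5):
    """Call fn() once to warm up, then return the median of `repeats` timings."""
    fn()  # warm-up: the first call includes JIT compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def time_matmul(device, n=2048):
    # Imported lazily so median_seconds stays usable without jax installed.
    import jax
    import jax.numpy as jnp

    # device_put commits the array to `device`, so the jitted call runs there.
    x = jax.device_put(jnp.ones((n, n), dtype=jnp.float32), device)
    matmul = jax.jit(lambda a: a @ a)
    # block_until_ready() waits for the asynchronous computation to finish.
    return median_seconds(lambda: matmul(x).block_until_ready())


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("jax is not installed in this environment")
    else:
        default = jax.devices()[0]
        print(f"{default}: {time_matmul(default):.4f} s")
        if default.platform != "cpu":
            cpu = jax.devices("cpu")[0]
            print(f"{cpu}: {time_matmul(cpu):.4f} s")
```

If the default device is a GPU and its time is not clearly lower than the CPU's, something is off even though jax.devices() lists it.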
Oh OK. Well, I just have a consumer RX 7900 XTX, which is RDNA3 and, I think, not supported; I'm going to check again. gfx90a is CDNA2, so the MI250 is supported.
I am specifically asking about CDNA3, which requires upstream support!
On an MI100 (gfx908) system with Ubuntu 22.04 and ROCm 6.0.2, I compiled jax v0.4.24 from rocm/jax with rocm/xla and got the same error:
2024-03-14 02:07:55.868414: E external/xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: DNN
However, ROCm does work (with significantly lower performance):
>>> jax.devices()
[rocm(id=0)]
See also https://github.com/google/jax/issues/19453#issuecomment-1927543572!
https://github.com/openxla/xla/pull/10954 fixes that warning/error. Although it had no functional implications, it was annoying nonetheless.
@rahulbatra85 With ROCm 6.1, I am able to train basic convolutional networks on my 7900 XTX! I will run some more advanced networks and report back any issues that I find!