rocm 5.7.1 + 7900 xtx + jax:latest docker image not working

Open brettkoonce opened this issue 1 year ago • 11 comments

Hardware: 7900 XTX, Ubuntu 22.04 LTS, ROCm 5.7.1 (first version with official 7900 XTX support)

I am able to run the PyTorch image, e.g.:

drun rocm/pytorch
root@minerva:~# python3
Python 3.9.18 (main, Sep 11 2023, 13:41:44) 
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.device(0)
device(type='cuda', index=0)
>>> torch.cuda.get_device_properties(0).total_memory
25753026560

and have trained simple CNN models using this setup.

However, when I run the jax version of the same image I get the following error on launch of the image:

drun rocm/jax:rocm5.7.0-jax0.4.20-py3.11.0
root@minerva:/root# python3
Python 3.11.0 (main, Nov 16 2023, 20:45:15) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-11-29 19:56:58.717575: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
>>> jax.devices()
[CpuDevice(id=0)]

Is there something I should be doing differently? Is this configuration officially supported?
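A quick way to check whether jax picked up a GPU or silently fell back to CPU, as in the transcript above, is to look at the device platforms it reports. The following is a minimal sketch, not an official diagnostic: `on_accelerator` is a hypothetical helper name, and the import is guarded so the snippet also runs where jax is not installed.

```python
def on_accelerator(platforms):
    """Return True if any reported device platform is not 'cpu'.

    `platforms` is a list of strings such as ["cpu"] or ["rocm", "rocm"].
    (Hypothetical helper, introduced only for this illustration.)
    """
    return any(p != "cpu" for p in platforms)


try:
    import jax

    # Each jax device object exposes a `platform` string.
    platforms = [d.platform for d in jax.devices()]
    print("default backend:", jax.default_backend())
    print("device platforms:", platforms)
    print("accelerator in use:", on_accelerator(platforms))
except ImportError:
    print("jax is not installed in this environment")
```

If only CPU devices are listed, as with `[CpuDevice(id=0)]` above, the helper returns False and jax is running on the CPU backend.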

brettkoonce avatar Nov 30 '23 12:11 brettkoonce

same for the latest image:

drun rocm/jax:latest
[sudo] password for skoonce: 
root@minerva:/root# python3
Python 3.10.0 (default, Nov 16 2023, 22:24:12) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-11-30 12:48:26.684744: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN

brettkoonce avatar Nov 30 '23 12:11 brettkoonce

Per discussion in https://github.com/google/jax/issues/7598#issuecomment-1834072178, waiting for ROCm 6.0!

brettkoonce avatar Dec 02 '23 16:12 brettkoonce

Updated to ROCm 6.0. PyTorch image working!

Still getting this with rocm/jax:latest:

drun rocm/jax
root@minerva:/root# python3
Python 3.10.0 (default, Dec 12 2023, 19:54:24) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2023-12-16 22:23:14.261510: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN

brettkoonce avatar Dec 16 '23 22:12 brettkoonce

I used rocm/dev-ubuntu-22.04:latest with ROCm 6.0.0 and compiled jax from source at branch v0.4.23 (with the rocm/xla repo on the same branch). No dice: same error, and running the GPU tests via bazel crashed the system. It would be nice to get jax working.

edit: gfx1100 is simply not supported yet

mjolk avatar Jan 06 '24 12:01 mjolk

On another HPC system (France's Adastra, MI250) I managed to get jax running under Singularity.

I get the same error as you:

2024-01-09 00:35:03.802451: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN

but ROCm does work:

>>> jax.devices()
[rocm(id=0), rocm(id=1), rocm(id=2), rocm(id=3), rocm(id=4), rocm(id=5), rocm(id=6), rocm(id=7)]
>>> vs.samples.sharding
SingleDeviceSharding(device=rocm(id=0))
>>>

So now I wonder what this error actually means...

PhilipVinc avatar Jan 08 '24 23:01 PhilipVinc

default="gfx900,gfx906,gfx908,gfx90a,gfx1030": I'm thinking new GPUs aren't supported yet and fall back to CPU? So yes, it will work, but I don't think it's using your MI250? The error means it couldn't load the DNN plugin, meaning no GPU support.
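The reasoning about the default target list can be sketched as a simple membership check. This is only an illustration under stated assumptions: the target string is copied from the comment above, the architecture names for each card are the ones mentioned in this thread (gfx1100 for the RX 7900 XTX, gfx90a for the MI250, gfx908 for the MI100), and `is_default_target` is a hypothetical helper. It does not query the hardware or the actual build configuration.

```python
# Default target list as quoted in the comment above.
DEFAULT_TARGETS = "gfx900,gfx906,gfx908,gfx90a,gfx1030"


def is_default_target(arch, targets=DEFAULT_TARGETS):
    """Return True if `arch` is one of the comma-separated default targets.

    (Hypothetical helper, introduced only for this illustration.)
    """
    return arch in targets.split(",")


# Architecture names as given elsewhere in this thread.
cards = [("RX 7900 XTX", "gfx1100"), ("MI250", "gfx90a"), ("MI100", "gfx908")]

for name, arch in cards:
    status = "in default list" if is_default_target(arch) else "not in default list"
    print(f"{name} ({arch}): {status}")
```

Splitting on commas rather than using a substring test avoids false matches between similar names such as gfx90a and gfx900.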

mjolk avatar Jan 13 '24 12:01 mjolk

I think it is running calculations on the GPU (judging by the runtime of the code, it matches that of similar Nvidia GPUs, and it surely ain't CPU).

PhilipVinc avatar Jan 13 '24 16:01 PhilipVinc

Oh ok. Well, I just have a consumer RX 7900 XTX, which is RDNA3 and not supported, I think. I'm going to check again. gfx90a is CDNA2, so the MI250 is supported.

mjolk avatar Jan 13 '24 16:01 mjolk

I am specifically asking about RDNA3, which requires upstream support!

brettkoonce avatar Feb 13 '24 16:02 brettkoonce

On an MI100 (gfx908) system with Ubuntu 22.04 and ROCm 6.0.2, with jax v0.4.24 compiled from rocm/jax with rocm/xla, I got the same error:

2024-03-14 02:07:55.868414: E external/xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: DNN

However, ROCm does work (albeit with very low performance):

>>> jax.devices()
[rocm(id=0)]

allegro0132 avatar Mar 14 '24 02:03 allegro0132

See also https://github.com/google/jax/issues/19453#issuecomment-1927543572!

brettkoonce avatar Mar 19 '24 17:03 brettkoonce

https://github.com/openxla/xla/pull/10954 fixes the "Invalid plugin kind specified: DNN" warning/error. It had no functional implications, but it was annoying nonetheless.

Ruturaj4 avatar Mar 27 '24 00:03 Ruturaj4

@rahulbatra85 With ROCm 6.1, I am able to train basic convolutional networks on my 7900 xtx! I will run some more advanced networks and report back any issues that I find!

brettkoonce avatar Apr 20 '24 01:04 brettkoonce