How to run GraphCast on an AMD GPU?
Hello,
I have an AMD GPU (an MI50). Is there any way to run GraphCast using the ROCm libraries?
Hey @BigShuiTai - We're currently looking into this as well. At the moment, we're starting with JAX on ROCm. I'll keep you posted on what we find :)
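Before trying GraphCast itself, it can help to confirm that a ROCm-enabled JAX install actually compiles and runs on the GPU. The script below is a minimal sketch that assumes such a JAX build is already installed (on a CPU-only install it still runs but reports the CPU backend). The function names `gram` and `smoke_test` are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gram(x):
    # A plain matrix multiply, compiled by XLA for whichever backend is active.
    return x @ x.T


def smoke_test(n: int = 256) -> float:
    """Run one jitted matmul on the default device and return the mean entry.

    Each entry of ones(n, n) @ ones(n, n).T is a sum of n ones, so on a
    correctly working device the result is exactly n.
    """
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = gram(x).block_until_ready()
    return float(y.mean())


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    print("devices:", jax.devices())
    print("mean entry:", smoke_test())  # expected: 256.0
```

If the backend shows up as CPU here, GraphCast will silently run on the CPU too, so this is worth checking first.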
@BigShuiTai - this is awesome. Curious why you're setting PYTORCH_ROCM_ARCH. My understanding is that pytorch is not a dependency here.
@garrettbyrd check this out
It's left over from the settings I use for other models; you can ignore it.
@BigShuiTai What does your output look like for this solution? Am I correct in using dataset/source-era5_date-2022-01-01_res-0.25_levels-13_steps-01.nc as the input? During runs I keep getting floating point errors that lead to NaNs in the solution. E.g.,
E0507 15:48:58.498964 3333925 buffer_comparator.cc:156] Difference at 32: -0.371094, expected 388
E0507 15:48:58.499005 3333925 buffer_comparator.cc:156] Difference at 33: -0.371094, expected 388
E0507 15:48:58.499009 3333925 buffer_comparator.cc:156] Difference at 34: -0.371094, expected 386
E0507 15:48:58.499012 3333925 buffer_comparator.cc:156] Difference at 35: -0.371094, expected 392
E0507 15:48:58.499015 3333925 buffer_comparator.cc:156] Difference at 36: -0.371094, expected 382
E0507 15:48:58.499018 3333925 buffer_comparator.cc:156] Difference at 37: -0.371094, expected 382
E0507 15:48:58.499021 3333925 buffer_comparator.cc:156] Difference at 38: -0.371094, expected 386
E0507 15:48:58.499025 3333925 buffer_comparator.cc:156] Difference at 39: -0.371094, expected 388
E0507 15:48:58.499028 3333925 buffer_comparator.cc:156] Difference at 40: -0.371094, expected 388
E0507 15:48:58.499033 3333925 buffer_comparator.cc:156] Difference at 41: -0.371094, expected 392
Are you experiencing anything similar?
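The `buffer_comparator.cc` lines appear to come from XLA's GPU autotuner, which compares the outputs of candidate kernels against each other; checking the forecast directly shows whether NaNs actually reach the predictions. A minimal sketch, assuming the predictions are an `xarray.Dataset` (iterating one should yield its data variables) or a plain dict of arrays; `nonfinite_report` is a hypothetical helper, and the variable names in the demo are only illustrative.

```python
from collections.abc import Mapping

import numpy as np


def nonfinite_report(predictions: Mapping) -> dict:
    """Count NaN/inf entries per variable, keeping only variables that have any.

    `predictions` may be an xarray.Dataset or a dict mapping variable
    names to array-likes; each value is converted with np.asarray.
    """
    report = {}
    for name, values in predictions.items():
        arr = np.asarray(values)
        # np.isfinite is False for NaN, +inf and -inf.
        n_bad = int(arr.size - np.count_nonzero(np.isfinite(arr)))
        if n_bad:
            report[name] = n_bad
    return report


if __name__ == "__main__":
    # Synthetic stand-in for model output.
    demo = {
        "2m_temperature": np.array([280.0, np.nan, 281.5]),
        "geopotential": np.ones((2, 3)),
    }
    print(nonfinite_report(demo))  # {'2m_temperature': 1}
```

If NaNs do show up, one untested thing to try is rerunning with `XLA_FLAGS=--xla_gpu_autotune_level=0` exported before JAX starts, which turns off XLA's GPU autotuning; whether that avoids the mismatch on ROCm is not confirmed in this thread.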
Hi @garrettbyrd, I haven't run into this issue before, but I don't think it's a bug in JAX. Can you share more details, e.g. your GPU model or your inference code?
I am running on an MI210 with ROCm 6.3.1. Could you provide which ROCm version you're using? I am running the inference script you provided.
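When comparing setups like these (GPU model, ROCm version, JAX version), a short report from each machine makes differences easier to spot. The sketch below assumes the default ROCm install layout, where the version string is stored in `/opt/rocm/.info/version`; on other layouts it just reports "not found". The helper names are hypothetical. Recent JAX releases also provide `jax.print_environment_info()`, which prints package versions and may be enough on its own.

```python
import platform
from pathlib import Path
from typing import Optional

import jax


def rocm_version(root: str = "/opt/rocm") -> Optional[str]:
    """Return the ROCm version recorded under <root>/.info/version, or None.

    Assumes the default ROCm install layout; other layouts return None.
    """
    version_file = Path(root) / ".info" / "version"
    if version_file.is_file():
        return version_file.read_text().strip()
    return None


def environment_report() -> str:
    """Summarize Python, JAX, ROCm and device details for a bug report."""
    lines = [
        f"python: {platform.python_version()}",
        f"jax: {jax.__version__}",
        f"default backend: {jax.default_backend()}",
        f"rocm: {rocm_version() or 'not found'}",
    ]
    for device in jax.devices():
        # device_kind is the accelerator name reported by the runtime.
        lines.append(f"device {device.id}: {device.platform} / {device.device_kind}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(environment_report())
```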
I'm running GraphCast on an MI50 with ROCm 6.3.3.