How to run graphcast on AMD GPU?

Status: Open. BigShuiTai opened this issue 8 months ago • 8 comments

Hello,

I have an AMD GPU (an MI50). Is there any way to run GraphCast using the ROCm libraries?

BigShuiTai, Apr 12 '25 15:04

Hey @BigShuiTai, we're currently looking into this as well. For now, we're starting with JAX on ROCm. I'll keep you posted on what we find :)

fluidnumerics-joe, Apr 14 '25 19:04
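Before trying GraphCast itself, it can help to confirm which JAX packages are actually installed, because JAX falls back to the CPU backend (with a warning) when no GPU plugin is present. The sketch below is a hypothetical stdlib-only helper (the name `jax_related_packages` is ours, not part of GraphCast or JAX). It reads package metadata through `importlib.metadata` and never imports JAX, so it also runs on a machine where the GPU backend fails to initialize. It does not prove the GPU is usable; that still has to be checked from JAX itself.

```python
from importlib import metadata


def jax_related_packages(dists=None):
    """Map each installed distribution whose name contains 'jax' to its version.

    Hypothetical helper: it reads package metadata only and never imports JAX.
    """
    if dists is None:
        dists = metadata.distributions()
    found = {}
    for dist in dists:
        name = dist.metadata.get("Name")
        # Case-insensitive match so 'jax', 'jaxlib' and any ROCm plugin packages are listed.
        if name and "jax" in name.lower():
            found[name] = dist.version
    return found


if __name__ == "__main__":
    packages = jax_related_packages()
    if not packages:
        print("No JAX packages installed")
    for name, version in sorted(packages.items()):
        print(f"{name}=={version}")
```

Sharing this list alongside the GPU model and ROCm version makes it easier to compare setups in a thread like this one.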

Hi @fluidnumerics-joe,

I have resolved this issue; you can find the method on my fork :)

BigShuiTai, May 02 '25 09:05

@BigShuiTai, this is awesome. I'm curious why you're setting PYTORCH_ROCM_ARCH; my understanding is that PyTorch is not a dependency here.

@garrettbyrd, check this out.

fluidnumerics-joe, May 02 '25 12:05

That's left over from the settings I use for other models; you can ignore it.

BigShuiTai, May 02 '25 12:05
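PYTORCH_ROCM_ARCH is, as far as we know, read by PyTorch's ROCm build tooling rather than by JAX, so a leftover value should not affect GraphCast. A minimal sketch of a check for such leftovers, assuming that PyTorch-only list (the helper name and the list of variables are hypothetical and not exhaustive):

```python
import importlib.util
import os

# Assumption: variables that only PyTorch's ROCm tooling reads; JAX is not expected to consult them.
PYTORCH_ONLY_VARS = ("PYTORCH_ROCM_ARCH",)


def unused_pytorch_vars(env=None, torch_installed=None):
    """Return PyTorch-specific variables that are set although torch is not installed."""
    if env is None:
        env = os.environ
    if torch_installed is None:
        # find_spec returns None when the 'torch' package cannot be imported.
        torch_installed = importlib.util.find_spec("torch") is not None
    if torch_installed:
        return []
    return [name for name in PYTORCH_ONLY_VARS if env.get(name)]


if __name__ == "__main__":
    for name in unused_pytorch_vars():
        print(f"{name} is set but torch is not installed; it is not expected to affect JAX.")
```

The check only reports; leaving the variable set is harmless under the assumption above.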

@BigShuiTai, what does your output look like with this solution? Am I correct in using dataset/source-era5_date-2022-01-01_res-0.25_levels-13_steps-01.nc as the input? During runs I keep getting floating-point errors that lead to NaNs in the solution, e.g.:

E0507 15:48:58.498964 3333925 buffer_comparator.cc:156] Difference at 32: -0.371094, expected 388
E0507 15:48:58.499005 3333925 buffer_comparator.cc:156] Difference at 33: -0.371094, expected 388
E0507 15:48:58.499009 3333925 buffer_comparator.cc:156] Difference at 34: -0.371094, expected 386
E0507 15:48:58.499012 3333925 buffer_comparator.cc:156] Difference at 35: -0.371094, expected 392
E0507 15:48:58.499015 3333925 buffer_comparator.cc:156] Difference at 36: -0.371094, expected 382
E0507 15:48:58.499018 3333925 buffer_comparator.cc:156] Difference at 37: -0.371094, expected 382
E0507 15:48:58.499021 3333925 buffer_comparator.cc:156] Difference at 38: -0.371094, expected 386
E0507 15:48:58.499025 3333925 buffer_comparator.cc:156] Difference at 39: -0.371094, expected 388
E0507 15:48:58.499028 3333925 buffer_comparator.cc:156] Difference at 40: -0.371094, expected 388
E0507 15:48:58.499033 3333925 buffer_comparator.cc:156] Difference at 41: -0.371094, expected 392

Are you experiencing anything similar?

garrettbyrd, May 07 '25 20:05
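The messages come from `buffer_comparator.cc`, which, as far as we understand, is part of XLA's GPU autotuner: it compares the output of a candidate kernel against a reference and logs each element that disagrees. When comparing runs across machines, counting these lines is a quick first signal. A sketch assuming the log format shown above (the regex and the helper name `summarize_mismatches` are ours):

```python
import re
import sys

# Matches lines such as:
# E0507 15:48:58.498964 3333925 buffer_comparator.cc:156] Difference at 32: -0.371094, expected 388
PATTERN = re.compile(
    r"buffer_comparator\.cc:\d+\] Difference at (\d+): (\S+), expected (\S+)"
)


def summarize_mismatches(lines):
    """Return (index, got, expected) tuples for every buffer_comparator mismatch line."""
    mismatches = []
    for line in lines:
        match = PATTERN.search(line)
        if match:
            index, got, expected = match.groups()
            mismatches.append((int(index), float(got), float(expected)))
    return mismatches


if __name__ == "__main__":
    # Usage: python summarize_mismatches.py run.log [more.log ...]
    for path in sys.argv[1:]:
        with open(path) as log:
            found = summarize_mismatches(log)
        print(f"{path}: {len(found)} mismatching elements")
```

Non-matching lines are ignored, so the whole run log can be passed in unchanged.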

Hi @garrettbyrd, I haven't run into this issue before, but I don't think it's a bug in JAX. Could you share more details, e.g., your GPU model or your inference code?

BigShuiTai, May 07 '25 23:05

I am running on an MI210 with ROCm 6.3.1. Which ROCm version are you using? I am running the inference script you provided.

garrettbyrd, May 08 '25 14:05
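The installed ROCm release can usually be read from the install tree rather than recalled from memory. This sketch assumes the default `/opt/rocm` prefix and that the release string is stored in `.info/version`; containers or custom prefixes may differ, so pass another path if needed. The helper name `rocm_version` is hypothetical. If you also need the GPU architecture, `rocminfo` lists each agent's gfx target.

```python
from pathlib import Path

# Assumption: a default ROCm install records its release string in this file.
DEFAULT_VERSION_FILE = Path("/opt/rocm/.info/version")


def rocm_version(version_file=DEFAULT_VERSION_FILE):
    """Return the raw ROCm version string, or None if the file does not exist."""
    path = Path(version_file)
    if not path.is_file():
        return None
    return path.read_text().strip()


if __name__ == "__main__":
    print(rocm_version() or "ROCm version file not found (non-default install?)")
```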

I'm running GraphCast on an MI50 with ROCm 6.3.3.

BigShuiTai, May 08 '25 14:05
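At this point the two reports differ in two ways: the GPU architecture (the MI50 is gfx906, the MI210 is gfx90a) and the ROCm patch release (6.3.3 vs 6.3.1). Either could explain why one setup works and the other produces NaNs. A tiny, hypothetical checklist helper makes that explicit; the GPU-to-target table covers only the two cards mentioned in this thread.

```python
# GPU model -> LLVM gfx target, limited to the cards in this thread.
GFX_TARGETS = {"MI50": "gfx906", "MI210": "gfx90a"}


def compare_setups(a, b):
    """Return the fields in which two {'gpu': ..., 'rocm': ...} setups differ."""
    differences = {}
    gfx_a, gfx_b = GFX_TARGETS[a["gpu"]], GFX_TARGETS[b["gpu"]]
    if gfx_a != gfx_b:
        differences["gfx_target"] = (gfx_a, gfx_b)
    if a["rocm"] != b["rocm"]:
        differences["rocm"] = (a["rocm"], b["rocm"])
    return differences


if __name__ == "__main__":
    working = {"gpu": "MI50", "rocm": "6.3.3"}   # BigShuiTai's report
    failing = {"gpu": "MI210", "rocm": "6.3.1"}  # garrettbyrd's report
    print(compare_setups(working, failing))
```

Narrowing the difference to a single axis (for example, the same ROCm release on both cards) would be the natural next step before attributing the NaNs to either one.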