
Add TPU support

Open jacobbieker opened this issue 2 years ago • 6 comments

Detailed Description

Currently, the graph neural network library dependencies (PyTorch Geometric) don't seem to support TPUs, apparently because of their custom kernels. We could add a JAX version for TPU support? The original model was apparently implemented in JAX.

Context

Being able to use TPUs could speed up training quite a bit.

Possible Implementation

jacobbieker avatar Jun 16 '22 14:06 jacobbieker

Hey @jacobbieker, I wanted to start contributing to openclimatefix and take a shot at this if this feature is still of interest.

vballoli avatar Oct 13 '22 04:10 vballoli

Hi, yeah that would be awesome!

jacobbieker avatar Oct 13 '22 13:10 jacobbieker

Is this still of interest? I'm thinking of working on it over a longer time frame.
The bottleneck would be getting the graph block working, but I managed to get some of the MLP components functioning, using Flax (Linen) + Jraph.
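For reference, the MLP piece mentioned above can be sketched in pure JAX, without the Flax (Linen) + Jraph scaffolding; this is a minimal illustration of the kind of block involved (layer sizes and initialisation are placeholders, not the actual graph_weather architecture), and `jax.jit` is what gets XLA compilation for CPU/GPU/TPU:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Initialise (weight, bias) pairs for each layer; sizes is e.g. [in, hidden, out]."""
    params = []
    for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        # He-style scaling for ReLU layers
        w = jax.random.normal(sub, (in_dim, out_dim)) * jnp.sqrt(2.0 / in_dim)
        params.append((w, jnp.zeros(out_dim)))
    return params

def mlp_apply(params, x):
    """Forward pass: ReLU between hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

params = init_mlp(jax.random.PRNGKey(0), [16, 64, 8])
# jit compiles the forward pass through XLA, so the same code runs on TPU
out = jax.jit(mlp_apply)(params, jnp.ones((4, 16)))
```

Flax's `nn.Module` and Jraph's `GraphNetwork` wrap this same pattern with parameter management and graph batching, which is where the graph-block difficulty comes in.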

aavashsubedi avatar Mar 12 '24 10:03 aavashsubedi

@jacobbieker

aavashsubedi avatar Mar 12 '24 11:03 aavashsubedi

Yeah, this is still of interest! It would be good to be able to use this on TPUs. Ideally still in PyTorch, although up for JAX as well.

jacobbieker avatar Mar 12 '24 11:03 jacobbieker

Hi, Amazing! I wasn't aware that you could use XLA with PyTorch until just now! : ) For inference, I guess this would be the easiest approach:

```python
import torch_xla.utils.serialization as xser

# Load a checkpoint that was saved with xser.save() on a TPU
model.load_state_dict(xser.load('model.pt'))
```

https://stackoverflow.com/questions/69328983/are-pytorch-trained-models-transferable-between-gpus-and-tpus To keep the mismatch between versions small, I'll stick to PyTorch (maybe there's a clean way to wrap TPU support around each GPU model, will see!). Will read up on XLA : )
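The portability point from that StackOverflow answer can be shown with plain PyTorch (no torch_xla needed): saving the `state_dict` rather than the whole model keeps the checkpoint device-agnostic, and `map_location` on load moves the weights wherever you need them. The small `nn.Sequential` below is just a stand-in for the graph_weather model:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model: any nn.Module checkpoints the same way
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Saving the state_dict keeps the checkpoint device-agnostic:
# it is just a dict of weight tensors, with no device baked in
path = os.path.join(tempfile.mkdtemp(), 'model.pt')
torch.save(model.state_dict(), path)

# map_location moves the weights onto whatever device is available on load,
# e.g. 'cpu' here, or an XLA device when running under torch_xla
state = torch.load(path, map_location='cpu')
model.load_state_dict(state)
```

Under torch_xla, the equivalent step would load onto the XLA device instead; the checkpoint file itself is the same either way (as long as it was saved with plain `torch.save`, not `xser.save`).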

aavashsubedi avatar Mar 12 '24 11:03 aavashsubedi