graph_weather
Add TPU support
Detailed Description
Currently, the graph neural network dependencies (PyTorch Geometric in particular) don't seem to support TPUs, probably because of their custom kernels. We could add a JAX version for TPU support. The original model was apparently implemented in JAX.
Context
Being able to use TPUs could speed up training quite a bit.
Possible Implementation
Hey @jacobbieker, I wanted to start contributing to openclimatefix and take a shot at this if this feature is still of interest.
Hi, yeah that would be awesome!
Is this still of interest? Thinking of working on it over a long time frame.
The bottleneck would be getting the graph block working, but I have managed to get some of the MLP components functioning.
Using Flax (Linen) + Jraph.
@jacobbieker
Yeah, this is still of interest! It would be good to be able to use this on TPUs. Ideally still in PyTorch, although up for JAX as well.
Hi, amazing! I wasn't aware that you can use XLA with PyTorch until just now! :) For inference, I guess this would be the easiest approach:
import torch_xla.utils.serialization as xser
model.load_state_dict(xser.load('model.pt'))
https://stackoverflow.com/questions/69328983/are-pytorch-trained-models-transferable-between-gpus-and-tpus To keep version mismatches down, I'll stick to PyTorch (maybe there's a cool way to wrap TPU support around each GPU model, we'll see!). Will read up on XLA :)
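A rough sketch of what wrapping TPU support around an existing GPU model could look like on the device-selection side. This is only an assumption about one possible approach, not anything that exists in graph_weather: the helper names `available_backend` and `get_device` are hypothetical, and the sketch assumes `torch_xla` still exposes `xm.xla_device()` in the installed version (the exact call may differ between torch_xla releases). It only decides which device to use; it does not solve the custom-kernel problem in the graph layers.

```python
import importlib.util


def available_backend() -> str:
    """Return 'xla', 'cuda', or 'cpu' based on what is installed (hypothetical helper).

    Only checks whether the packages can be found; it does not verify
    that a TPU is actually attached.
    """
    if importlib.util.find_spec("torch_xla") is not None:
        return "xla"
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            return "cuda"
    return "cpu"


def get_device():
    """Return a device object for the chosen backend (hypothetical helper).

    Requires PyTorch; imports are deferred so the module can be imported
    on machines without torch_xla.
    """
    backend = available_backend()
    if backend == "xla":
        # Assumption: this call is available in the installed torch_xla version.
        import torch_xla.core.xla_model as xm

        return xm.xla_device()
    import torch

    return torch.device(backend)
```

With PyTorch installed, usage would be along the lines of `model.to(get_device())`, with input tensors moved to the same device before the forward pass.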