graph_weather
More efficient way of encoding the input graph/output graph?
Currently, one of the issues with this implementation is that when there is a large number of input lat/lon coordinates (such as a 1° grid or finer), the graphs describing the connections between the inputs and the latent graph become huge, and the model has a hard time fitting on a GPU, especially with any batch size larger than 1. It seems like there should be a better way of encoding the inputs into the latent graph than the way I wrote it in this repo, but I'm not sure how yet.
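To put rough numbers on how fast this grows, here is a back-of-envelope sketch. It assumes each input grid point sends edges to `k` latent nodes, that edge indices are int64 and per-edge hidden messages are float32, and that batching replicates the graph once per sample (as PyTorch Geometric's disjoint-graph batching does). The values `k = 4` and `hidden = 256` and the helper names are hypothetical, chosen for illustration rather than taken from graph_weather. Only one per-edge tensor is counted, so real memory use (several MLP layers, activations kept for backprop) will be higher.

```python
"""Back-of-envelope size of a grid-to-latent encoder graph (illustrative only).

Hypothetical assumptions, not graph_weather's actual settings: every input
grid point connects to `k` latent nodes, edge_index is 2 x E int64, one
float32 hidden vector is stored per edge, and batching stacks `batch_size`
copies of the graph.
"""


def grid_points(resolution_deg: float) -> int:
    """Number of points on a regular lat/lon grid, including both poles."""
    n_lat = round(180 / resolution_deg) + 1
    n_lon = round(360 / resolution_deg)
    return n_lat * n_lon


def encoder_graph_bytes(resolution_deg: float, k: int = 4,
                        hidden: int = 256, batch_size: int = 1) -> int:
    """Lower-bound bytes for edge_index plus one per-edge hidden tensor."""
    n_edges = grid_points(resolution_deg) * k
    index_bytes = 2 * 8 * n_edges         # edge_index: 2 rows of int64
    message_bytes = 4 * hidden * n_edges  # one float32 vector per edge
    # Replicating the graph per sample scales everything linearly with batch.
    return batch_size * (index_bytes + message_bytes)


if __name__ == "__main__":
    for res in (1.0, 0.5, 0.25):
        for b in (1, 4):
            gib = encoder_graph_bytes(res, batch_size=b) / 2**30
            print(f"{res:>5} deg, batch {b}: "
                  f"{grid_points(res):>9,} nodes, ~{gib:.2f} GiB")
```

Under these assumptions a 1° grid already has 65,160 input nodes and needs roughly 0.25 GiB for a single per-edge tensor, while a 0.25° grid has over a million nodes and needs roughly 4 GiB per sample, which then multiplies with batch size.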
Hi, I'm thinking of taking this issue. Is the graph you're talking about the `Encoder.graph` member?
Yes! That's the one. There are now some other implementations of similar models that might be helpful, primarily NVIDIA's GraphCast implementation here: https://github.com/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_net.py