
GraphCast: Fine-tuning (Impact: moderate, effort: high)

Open mnabian opened this issue 1 year ago • 4 comments

mnabian avatar Jun 18 '24 18:06 mnabian

Hello @mnabian,

Thank you for your excellent implementation in Modulus! I was wondering whether it is possible to reuse the pretrained checkpoint from the original JAX model. I'm exploring ways to leverage the pre-trained model for fine-tuning in this context.

Have there been any attempts to convert JAX model weights for use in the torch modulus implementation? Thanks in advance!
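For anyone attempting this, the usual recipe is to flatten the nested JAX parameter tree into dotted keys, rename each entry to the corresponding PyTorch parameter, and transpose dense kernels (JAX stores them as `(in, out)`, `torch.nn.Linear` as `(out, in)`). Below is a minimal sketch of that recipe; it assumes the JAX checkpoint has already been materialized as a nested dict of numpy arrays, and the function/parameter names are illustrative, not part of either codebase:

```python
import numpy as np

def flatten_params(params, prefix=""):
    """Recursively flatten a nested JAX-style param dict into dotted keys.

    `params` is assumed to be a nested dict whose leaves are numpy arrays
    (e.g. the result of `jax.device_get` on a Haiku/Flax checkpoint).
    """
    flat = {}
    for name, value in params.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = value
    return flat

def to_torch_state_dict(flat_params, name_map):
    """Map flattened JAX keys to PyTorch names, transposing 2-D kernels.

    `name_map` is a {jax_key: torch_key} dict that has to be built by
    inspecting both models' parameter names side by side.
    """
    state = {}
    for jax_key, torch_key in name_map.items():
        w = np.asarray(flat_params[jax_key])
        if w.ndim == 2 and torch_key.endswith(".weight"):
            w = w.T  # (in, out) -> (out, in) for torch.nn.Linear
        state[torch_key] = w
    return state
```

The resulting arrays can then be wrapped in tensors and loaded with `load_state_dict(..., strict=False)` to surface any keys that still mismatch.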

Best, Koublal

oublalkhalid avatar Jan 04 '25 14:01 oublalkhalid

Looking forward to further progress on the GraphCast model, or a demonstration of training results comparing Modulus with DeepMind's JAX implementation.

Flionay avatar Jan 09 '25 07:01 Flionay

Hello @mnabian,

Getting back to you after a few days of improvements. To be honest, I found DALI's iterator not very efficient, so I implemented an iterable dataset directly. Please find below my benchmarking results comparing different dataloader strategies; the `__getitem__` approach does not seem to be optimized for samples of this size. In my PR I'll also suggest some modifications to the validation setup, since the validation loop runs indefinitely at certain steps. My current version, trained on five years of data (2013-2017), achieves performance comparable to the JAX-based implementation from GraphCast (DeepMind Weather Team)!
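The core idea behind such an iterable dataset is to walk the time axis sequentially (contiguous disk reads) instead of issuing independent random-access `__getitem__` loads, sharding the index range across workers. A minimal sketch of that pattern, with the torch dependency stripped out for brevity (in practice this logic would live in a `torch.utils.data.IterableDataset` subclass, with `shard_id`/`num_shards` derived from `torch.utils.data.get_worker_info()`); all names are illustrative, not the PhysicsNeMo API:

```python
import numpy as np

class SequentialWindowIterable:
    """Stream contiguous sample windows from a time-major array.

    `data` is any array-like of shape (T, C, H, W), e.g. an np.memmap
    over an ERA5 store; `window` is the number of consecutive time steps
    per sample (inputs plus rollout targets).
    """

    def __init__(self, data, window, shard_id=0, num_shards=1):
        self.data = data
        self.window = window
        self.shard_id = shard_id
        self.num_shards = num_shards

    def __iter__(self):
        n_samples = self.data.shape[0] - self.window + 1
        # Each shard takes a strided slice of the sample indices, so
        # shards never overlap and the epoch is covered exactly once.
        for start in range(self.shard_id, n_samples, self.num_shards):
            yield self.data[start:start + self.window]
```

Because each worker advances through a contiguous stride of the time axis, the underlying reads stay mostly sequential, which is typically where the speedup over per-item random access comes from.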

[Figure: benchmarking results for the different dataloader strategies]

I look forward to your feedback on my PR!

Best, khalid

oublalkhalid avatar Apr 25 '25 23:04 oublalkhalid

Hi @oublalkhalid, thank you for your graph and explanation, it looks awesome! Did you train it with static channels? I have been trying to, but training crashes in the fine-tuning part. During the multi-step rollout used for fine-tuning, the predicted vector (weather channels only) does not match the size of the input vector (static plus weather channels): the predicted output is fed back in directly without re-appending the static channels, which leads to shape-mismatch errors. It crashes whenever the static channel count is > 0, or when `use_time_of_year_index:` or `use_cos_zenith:` is set to true.

Did you encounter the same error, or am I doing something wrong when launching the training? Thanks in advance for your help ;) Kind regards, Julio
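The shape mismatch described above can be avoided by re-concatenating the static channels before every rollout step, since the model predicts only the prognostic channels. A minimal sketch of that fix, assuming channel-first tensors; `step_fn` and the shapes are illustrative, not the PhysicsNeMo API:

```python
import numpy as np

def rollout(step_fn, x, static, n_steps):
    """Multi-step rollout that re-appends static channels at every step.

    `x` holds only the prognostic (weather) channels; `static` holds the
    time-invariant channels (land-sea mask, orography, ...). The one-step
    model `step_fn` consumes the concatenation of both but predicts only
    the prognostic channels, so the static channels must be concatenated
    again before each step rather than only for the first input.
    """
    preds = []
    for _ in range(n_steps):
        model_in = np.concatenate([x, static], axis=0)  # channel axis
        x = step_fn(model_in)  # prognostic channels only
        preds.append(x)
    return preds
```

The same re-append would apply to the time-dependent auxiliary channels (time-of-year index, cosine zenith angle), except that those must be recomputed for each step's timestamp rather than held fixed.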

Betancourt20 avatar Nov 05 '25 11:11 Betancourt20