How to fine-tune graphcast?
Is there any sample code for fine-tuning graphcast?
I also want to know. I can run the notebook code and get the loss and gradients, but the loss doesn't seem to backpropagate. The notebook also doesn't cover how to save a randomly initialized model.
I managed to get it working with a few modifications to the example notebook code. Here are the changes I made (ignoring all unchanged code from the example):
import optax

# modify the gradients function signature
def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0), model_config, task_config, i, t, f)
        return loss, (diagnostics, next_state)
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
        _aux, has_aux=True)(params, state, inputs, targets, forcings)
    return loss, diagnostics, next_state, grads

# remove `with_params` from the jitted grads function
grads_fn_jitted = jax.jit(with_configs(grads_fn))

# setup optimiser
lr = 1e-3
optimiser = optax.adam(lr, b1=0.9, b2=0.999, eps=1e-8)
opt_state = optimiser.init(params)

# calculate loss and gradients
loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)

# update
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)
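To fine-tune rather than take a single step, the update above just needs to be repeated over your training data. A minimal sketch, assuming the objects defined above and a hypothetical get_training_batch helper that yields (inputs, targets, forcings) from your own data pipeline:

# minimal fine-tuning loop; get_training_batch is a hypothetical data helper
num_steps = 1000
for step in range(num_steps):
    inputs, targets, forcings = get_training_batch(step)
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)
    updates, opt_state = optimiser.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if step % 100 == 0:
        print(f'step {step}: loss {float(loss)}')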
I'm afraid we don't provide a training script in this codebase as our training setup is quite tied to internal infrastructure, but I think we give enough to construct one yourself if you want one, and the above is a good start. Note you may struggle to fine-tune the 0.25deg model unrolled to 3 days without extra tricks to save GPU/TPU RAM (see other issue on this).
I'll leave this open in case any others want to contribute example code.
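For anyone building on the snippet above, here is a sketch of a slightly fuller optax setup with learning-rate warmup, cosine decay and gradient clipping. The schedule lengths and clipping value are placeholders for illustration, not a claim about the original GraphCast training settings.

# optional: replace the fixed learning rate with a warmup + cosine-decay
# schedule and add global-norm gradient clipping (values are placeholders)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=1_000,
    decay_steps=100_000,
    end_value=0.0)
optimiser = optax.chain(
    optax.clip_by_global_norm(32.0),
    optax.adam(schedule, b1=0.9, b2=0.95, eps=1e-8))
opt_state = optimiser.init(params)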
For saving and loading the model parameters, I came up with these functions. @illuSION-crypto
import os

import jax
import jax.numpy as jnp
import numpy as np

def flatten_dict(d, parent_key='', sep='//'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def save_model_params(d, file_path):
    flat_dict = flatten_dict(d)
    # Convert JAX arrays to NumPy for saving
    np_dict = {k: np.array(v) if isinstance(v, jnp.ndarray) else v for k, v in flat_dict.items()}
    np.savez(file_path, **np_dict)

params_path = os.path.join('path/to/params', 'params.npz')
save_model_params(params, params_path)

def unflatten_dict(flat, sep='//'):
    result_dict = {}
    for flat_key, value in flat.items():
        keys = flat_key.split(sep)
        current = result_dict
        for key in keys[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
        current[keys[-1]] = value
    return result_dict

def load_model_params(file_path):
    with np.load(file_path, allow_pickle=True) as npz_file:
        # Convert NumPy arrays back to JAX arrays
        jax_dict = {k: jnp.array(v) for k, v in npz_file.items()}
    return unflatten_dict(jax_dict)

params = load_model_params(params_path)
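As a quick sanity check, you can confirm the round trip preserves both structure and values. A sketch using standard JAX tree utilities; original_params here stands for the in-memory tree you saved, kept around before reassigning params above:

# verify the save/load round trip: same tree structure, shapes and values
loaded = load_model_params(params_path)
assert jax.tree_util.tree_structure(original_params) == jax.tree_util.tree_structure(loaded)
for saved_leaf, loaded_leaf in zip(jax.tree_util.tree_leaves(original_params),
                                   jax.tree_util.tree_leaves(loaded)):
    assert saved_leaf.shape == loaded_leaf.shape
    assert np.allclose(np.asarray(saved_leaf), np.asarray(loaded_leaf))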
Thanks @ChrisAGBlake for sharing! Is it possible to modify your script to allow for multiple GPUs?
Sure, I used pmapped versions of the functions (via xarray_jax.pmap) to distribute the work across multiple GPUs:
import functools

# setup the update function
@functools.partial(xarray_jax.pmap, dim='device', axis_name='device')
def multi_gpu_update(params, state, opt_state, inputs, targets, forcings):
    # calculate loss and gradients
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)
    # combine the gradients across devices
    grads = jax.lax.pmean(grads, axis_name='device')
    # combine the loss across devices
    loss = jax.lax.pmean(loss, axis_name='device')
    # update
    updates, new_opt_state = optimiser.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, loss, new_opt_state

# setup the loss function for evaluation
@functools.partial(xarray_jax.pmap, dim='device', axis_name='device')
def multi_gpu_loss(params, state, inputs, targets, forcings):
    # calculate loss
    (loss, diagnostics), next_state = loss_fn_jitted(params, state, jax.random.PRNGKey(0), inputs, targets, forcings)
    # combine the loss across devices
    loss = jax.lax.pmean(loss, axis_name='device')
    return loss
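To drive these, each argument needs a leading per-device copy: the parameters, state and optimiser state replicated, and the data stacked along the 'device' dimension. A rough sketch, assuming a hypothetical stack_on_device_dim helper for the data side (I haven't checked exactly how xarray_jax.pmap expects that dimension to be laid out):

devices = jax.local_devices()

# replicate parameters, state and optimiser state, one copy per device
params_repl = jax.device_put_replicated(params, devices)
state_repl = jax.device_put_replicated(state, devices)
opt_state_repl = jax.device_put_replicated(opt_state, devices)

# hypothetical helper: stacks one training example per device along a new
# leading 'device' dimension of each xarray Dataset
inputs_d, targets_d, forcings_d = stack_on_device_dim(batch, devices)

params_repl, loss, opt_state_repl = multi_gpu_update(
    params_repl, state_repl, opt_state_repl, inputs_d, targets_d, forcings_d)
print('loss (averaged across devices):', float(loss[0]))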
Thanks @ChrisAGBlake and @monte-flora for sharing. I saved and loaded the model with the save_model_params and load_model_params functions respectively, but I have trouble making predictions with it: when I use the autoregressive rollout (loop in Python) and loss computation code, the loss looks like that of a randomly initialized model. It should be much lower than what I got.
prediction = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions.append(prediction)
if i in [0, 1, 2]:
    # loss computation (autoregressive loss over multiple steps)
    loss, diagnostics = loss_fn_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    print("Loss:", float(loss))
It is as if the newly loaded parameters are not taken into account in the predictions. Any clue on how to make this work? Thank you.
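One thing worth checking, assuming the helper functions from the example notebook: with_params binds whatever value the params variable had at the moment run_forward_jitted and loss_fn_jitted were constructed, so reloading parameters into params later has no effect on the already-built wrappers. A sketch of rebuilding them after loading (not verified against your exact setup):

# reload the fine-tuned parameters, then rebuild the jitted functions so they
# close over the new values instead of the ones captured earlier
params = load_model_params(params_path)
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))
loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))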
Thanks @ChrisAGBlake. I ran into a problem: my GPU only has 12 GB of memory, and I get an OOM error when computing the loss and gradients. Is there any solution?
You'll probably need more GPU memory than that unless you're trying to train a low-resolution model. One workaround is to offload to the CPU, which you can do by putting the following couple of lines at the top of your script.
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'
This will try to use CPU memory to double the memory you have available beyond the GPU alone, 24 GB in total. You can adjust '2.0' to whatever value is required to fit the update (within your available RAM).
I have found that I require ~80GB with a model at 0.25 deg resolution.
Another solution is just to use vast.ai, AWS, GCP, etc.
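A couple of other standard JAX memory knobs are worth trying before renting bigger hardware. These are documented JAX/XLA environment variables and also need to be set before JAX initialises the GPU; whether they are enough for the 0.25 deg model is another question.

import os
# don't preallocate most of the GPU memory at startup
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# allocate and free on demand instead of caching (slower, but lowers peak usage)
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'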
@vsansi I load the trained model like this
# load the model
with open(checkpoint_file, 'rb') as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config
Then I generate predictions like this:
# run forward autoregressively to get the predictions
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings)

# write out the predictions
pred_file = f'{save_dir}/{date.strftime("%Y-%m-%dT%H")}_predictions.nc'
predictions.to_netcdf(pred_file)
Hope this helps.
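As a quick check that the loaded parameters are actually being used, you can compare the predictions against the matching targets directly. A small sketch using plain xarray arithmetic, assuming predictions and targets share coordinates; a pretrained or fine-tuned model should score far better than a randomly initialized one:

# per-variable RMSE between the autoregressive predictions and the targets
rmse = np.sqrt(((predictions - targets) ** 2).mean())
for var in rmse.data_vars:
    print(var, float(rmse[var]))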