Chris Blake
I managed to get it working with a few modifications to the example notebook code. Here are the changes I made (ignoring all unchanged code from the example): ```python import...
Sure, I used jax.pmap versions of the functions to distribute them across multiple GPUs: ```python # setup the update function @functools.partial(xarray_jax.pmap, dim='device', axis_name='device') def multi_gpu_update(params, state, opt_state, inputs, targets, forcings): # calculate...
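For anyone not using the graphcast helpers, the same data-parallel idea can be sketched with plain `jax.pmap` instead of `xarray_jax.pmap`. This is a minimal sketch, not the code from the comment: the toy linear model, `LEARNING_RATE`, `multi_device_update`, and the `replicate`/`shard` helpers are all hypothetical, and plain SGD stands in for the real optimizer. Each device gets one slice of the batch, and `jax.lax.pmean` averages the gradients so every replica applies the same update.

```python
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for plain SGD


def loss_fn(params, inputs, targets):
    # Hypothetical toy linear model standing in for the real network.
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


@functools.partial(jax.pmap, axis_name="device")
def multi_device_update(params, inputs, targets):
    # Each device computes loss and gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    # Average across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="device")
    loss = jax.lax.pmean(loss, axis_name="device")
    # Plain SGD step (a stand-in for an optax optimizer).
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


def replicate(tree, n):
    # Add a leading device axis of size n to every leaf of the pytree.
    return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (n,) + x.shape), tree)


def shard(x, n):
    # Reshape (batch, ...) into (n, batch // n, ...) so pmap maps over devices.
    return x.reshape((n, x.shape[0] // n) + x.shape[1:])


if __name__ == "__main__":
    n = jax.local_device_count()
    params = replicate({"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}, n)
    inputs = jnp.ones((4 * n, 3))
    targets = jnp.ones((4 * n, 1))
    params, loss = multi_device_update(params, shard(inputs, n), shard(targets, n))
    print(loss)  # one (identical) loss value per device
```

On a single-GPU or CPU-only machine `jax.local_device_count()` is 1, so the same code still runs; the only requirement is that the batch size is divisible by the number of devices.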
You'll probably need more GPU memory than that unless you're training a low-resolution model. One workaround is to offload to the CPU, which you can do by...
@vsansi I load the trained model like this ```python # load the model with open(checkpoint_file, 'rb') as f: ckpt = checkpoint.load(f, graphcast.CheckPoint) params = ckpt.params state = {} model_config =...
This works for me to download and format one day's worth of data; you can modify it for multiple days: ```python import cdsapi import os import datetime from netCDF4 import...
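One way to extend this to multiple days is to build one request per day and loop over them. The sketch below only constructs the request dictionaries, so it runs without CDS credentials; the helper name `daily_requests`, the default times, and the example variable are hypothetical, and the field names follow the ERA5 single-level request style, which may need adjusting for newer versions of the CDS.

```python
import datetime


def daily_requests(start, end, variables, times=("00:00", "06:00", "12:00", "18:00")):
    """Return one ERA5 single-level request dict per day, start to end inclusive."""
    requests = []
    day = start
    while day <= end:
        requests.append({
            "product_type": "reanalysis",
            "variable": list(variables),
            "year": f"{day.year:04d}",
            "month": f"{day.month:02d}",
            "day": f"{day.day:02d}",
            "time": list(times),
            "format": "netcdf",
        })
        day += datetime.timedelta(days=1)
    return requests


if __name__ == "__main__":
    reqs = daily_requests(datetime.date(2022, 1, 30), datetime.date(2022, 2, 1), ["2m_temperature"])
    for req in reqs:
        print(req["year"], req["month"], req["day"])
    # To actually download (requires cdsapi and a configured CDS account):
    # import cdsapi
    # client = cdsapi.Client()
    # for req in reqs:
    #     target = f"era5_{req['year']}{req['month']}{req['day']}.nc"
    #     client.retrieve("reanalysis-era5-single-levels", req, target)
```

Building one request per day keeps each download small and makes it easy to skip days whose files already exist if a run gets interrupted.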
@LipJ01 Have you had any success with this? I'm also looking to export models trained on a custom environment to TensorFlow and then ONNX. @danijar Thank you so much for...
I managed to convert it to TensorFlow but wasn't able to convert it to TFLite or ONNX, as it seems to use some operations that those formats don't support. For anyone...