neuralgcm
                        Batched Inference Possible?
Hey, thanks again for this project. I was wondering if it's possible to do batched inference with the current API. Thanks!
Yes, it should work.
Just use jax.vmap to wrap various methods, e.g.,
initial_state = jax.vmap(model.encode)(...)
final_state, predictions = jax.vmap(model.unroll)(initial_state, ...)
You may need to adjust the in_axes argument and/or write a helper function.
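To illustrate the `in_axes` point, here is a minimal sketch with a toy stand-in function (the `unroll` below is hypothetical, not the NeuralGCM API): map over the leading axis of the batched states while broadcasting a single shared `forcings` argument unchanged.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a model method; placeholder computation only.
def unroll(state, forcings):
    return state + forcings

batch_of_states = jnp.arange(8.0).reshape(4, 2)  # leading batch axis of 4
shared_forcings = jnp.ones(2)                    # same forcings for every member

# in_axes=(0, None): vectorize over axis 0 of states, broadcast forcings as-is.
batched_unroll = jax.vmap(unroll, in_axes=(0, None))
out = batched_unroll(batch_of_states, shared_forcings)
print(out.shape)  # (4, 2)
```

The same pattern applies when some model arguments (e.g., shared parameters or forcings) should not gain a batch dimension.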
I have not tested it, but I expect that this could result in significant speed-ups for inference on GPU, especially for coarser resolution simulations.
Thanks again for the pointer. I managed to set it up like this:
import jax
import jax.numpy as jnp
from random import randint

def get_random_state():
    # Encode with a fresh RNG seed so each ensemble member differs.
    return neural_gcm_model.encode(
        inputs, forcings=input_forcings,
        rng_key=jax.random.PRNGKey(randint(0, 191398193819238)))

def stack_initial_states(initial_states):
    # Stack matching pytree leaves along a new leading batch axis.
    # (tree_multimap is deprecated; jax.tree_util.tree_map handles multiple trees.)
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *initial_states)

def get_initial_state(n_states: int):
    return stack_initial_states([get_random_state() for _ in range(n_states)])

def unroll_fn(initial_state):
    return neural_gcm_model.unroll(
        initial_state, forcings=forcings, steps=outer_steps,
        timedelta=timedelta, start_with_input=True)

vectorized_unroll_fn = jax.vmap(unroll_fn)
final_states, predictions = vectorized_unroll_fn(get_initial_state(n_states))
However, sadly I am not getting any speedup over just running in a loop. I am not too familiar with JAX; did I make any obvious mistake?
I checked the results and they look correct, so the computation is happening as expected. Wrapping the vmap in jax.jit also didn't help.
This looks right to me at a high level.
I'm honestly not sure exactly what to expect here. NeuralGCM already uses a fair amount of parallelism inside each model, so it's possible there isn't much speed-up available.
One thing to check is that the forward computation is actually jit-compiled: wrap jax.jit around your vmapped function and call it at least twice, so you aren't measuring tracing/compilation time:
https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
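The benchmarking pattern from that FAQ looks roughly like this (the function `f` here is just a stand-in for the forward computation): compile once, discard the first call, and use `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) @ jnp.cos(x).T  # stand-in for the real forward pass

jitted = jax.jit(f)
x = jnp.ones((512, 512))

# First call includes tracing + compilation; exclude it from timing.
jitted(x).block_until_ready()

start = time.perf_counter()
jitted(x).block_until_ready()  # block so we time the actual computation
elapsed = time.perf_counter() - start
print(f"steady-state time: {elapsed:.4f}s")
```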
If you want to do a deep dive into performance, take a look at JAX's profiling tools: https://jax.readthedocs.io/en/latest/profiling.html
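For the profiling route, a minimal sketch (the log directory name is arbitrary) is to wrap the computation in `jax.profiler.trace`, which writes a trace you can open in TensorBoard or Perfetto:

```python
import jax
import jax.numpy as jnp

# Record a program trace for later inspection in TensorBoard/Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((256, 256))
    y = jax.jit(lambda a: a @ a)(x)
    y.block_until_ready()  # make sure the traced work actually finishes
```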
@leanderloew were you ever able to resolve this? I am having a similar issue and followed a similar procedure, yet I am unable to figure it out.