
Batched Inference Possible?

Open leanderloew opened this issue 1 year ago • 4 comments

Hey, thanks again for this project. I was wondering if it's possible to do batched inference with the current API. Thanks!

leanderloew avatar May 17 '24 08:05 leanderloew

Yes, it should work.

Just use jax.vmap to wrap various methods, e.g.,

```python
initial_state = jax.vmap(model.encode)(...)
final_state, predictions = jax.vmap(model.unroll)(initial_state, ...)
```

You may need to adjust the in_axes argument and/or write a helper function.
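A minimal sketch of what that could look like, batching only over a stack of RNG keys while broadcasting the shared inputs (this assumes an `encode(inputs, forcings, rng_key)` positional signature, so treat the argument order as illustrative rather than the exact NeuralGCM API):

```python
import jax

# Batch only over the RNG keys (leading axis 0); in_axes=None broadcasts
# the shared inputs and forcings unchanged to every batch member.
rng_keys = jax.random.split(jax.random.PRNGKey(0), num=8)
batched_encode = jax.vmap(model.encode, in_axes=(None, None, 0))
initial_states = batched_encode(inputs, forcings, rng_keys)
```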

I have not tested it, but I expect that this could result in significant speed-ups for inference on GPU, especially for coarser resolution simulations.

shoyer avatar May 17 '24 17:05 shoyer

Thanks again for the pointer. I managed to set it up like this:

```python
import jax
import jax.numpy as jnp
from random import randint

def get_random_state():
    # Encode one initial state with a fresh random key (stochastic encoder).
    return neural_gcm_model.encode(
        inputs, forcings=input_forcings,
        rng_key=jax.random.PRNGKey(randint(0, 191398193819238)))

def stack_initial_states(initial_states):
    # Stack a list of state pytrees into one pytree with a leading batch axis
    # (jax.tree_util.tree_map replaces the removed tree_multimap).
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *initial_states)

def get_initial_state(n_states: int):
    return stack_initial_states([get_random_state() for _ in range(n_states)])

def unroll_fn(initial_state):
    return neural_gcm_model.unroll(
        initial_state, forcings=forcings, steps=outer_steps,
        timedelta=timedelta, start_with_input=True)

vectorized_unroll_fn = jax.vmap(unroll_fn)
final_states, predictions = vectorized_unroll_fn(get_initial_state(n_states))
```

However, sadly I am not getting any speedup over just running in a loop. I am not too familiar with JAX; did I make any obvious mistake?

leanderloew avatar May 20 '24 17:05 leanderloew

I checked the results and they look good, so the computation is happening correctly. Wrapping the vmapped function in jax.jit also didn't help.

leanderloew avatar May 20 '24 17:05 leanderloew

This looks right to me at a high level.

I'm honestly not sure exactly what to expect here. NeuralGCM already uses a fair amount of parallelism inside each model, so it's possible there isn't much speed-up available.

One thing to check is that the forward computation is jit-compiled: wrap jax.jit around your vmapped function and then call it twice, so you aren't measuring tracing/compilation time: https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
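For reference, a minimal benchmarking sketch along those lines, reusing the `vectorized_unroll_fn` and `get_initial_state` helpers from your snippet:

```python
import time
import jax

jit_unroll = jax.jit(vectorized_unroll_fn)
batch = get_initial_state(8)

# First call traces and compiles; run it once before timing.
jax.block_until_ready(jit_unroll(batch))

# Second call measures steady-state execution; block_until_ready forces
# JAX's asynchronous dispatch to finish before the clock stops.
start = time.perf_counter()
jax.block_until_ready(jit_unroll(batch))
print(f"batched unroll took {time.perf_counter() - start:.3f} s")
```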

If you want to do a deep dive into performance, take a look at JAX's profiling tools: https://jax.readthedocs.io/en/latest/profiling.html
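For example, a short sketch using JAX's built-in trace context manager (the log directory here is arbitrary; view the result with TensorBoard's profiler plugin or Perfetto):

```python
import jax

# Writes a trace of everything executed inside the block to the log directory.
with jax.profiler.trace("/tmp/jax-trace"):
    out = jit_unroll(batch)
    jax.block_until_ready(out)  # make sure the work finishes inside the trace
```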

shoyer avatar May 20 '24 18:05 shoyer

@leanderloew were you ever able to resolve this? I am having a similar issue and followed a similar procedure, yet I'm unable to figure it out.

mcrlf avatar Sep 20 '24 17:09 mcrlf