Docs: please clarify how to vmap an nnx.Module over the batch dimension
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): WSL2 Ubuntu 22.04
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.8.5, 0.4.30, 0.4.30
- Python version: 3.12 (and whatever Colab uses today)
- GPU/TPU model and memory: NVIDIA RTX 4090, 24 GB
- CUDA version (if applicable): 12.5
Problem you have encountered:
I made an nnx.Module that processes a single input/output pair at a time and used vmap to apply it over a batch. The model works on the first pass, but crashes on the second pass.
What you expected to happen:
I expected that vmapping an nnx.Module over a batch of data would pass the same per-example shape on every call, not one shape on the first call and a different shape on the second.
I thought the docs could help, but they don't mention batch handling. Batching in the MNIST example seems to be assumed, which implies autobatching, but that just crashed when I tried it.
I tried various combinations of nnx.jit and jax.jit, and they all crashed with various shape bugs. Processing MNIST one example at a time with no batching worked, but batching is pretty necessary.
I don't want to hard-code models for a particular batch size. Can you please clarify in the docs how to map an nnx.Module over a batch without hard-coding the batch size in the kernel dimensions? Or, what am I doing wrong in the Colab?
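For reference, here is a minimal plain-JAX sketch (not nnx) of the pattern being asked about: the parameters are shaped only by feature dimensions, and `jax.vmap` supplies the batch axis. The names `init_params` and `predict_one` and the sizes 4 and 2 are hypothetical, chosen just for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example model: parameter shapes depend only on the
# feature dimensions (in_features, out_features), never on the batch size.
def init_params(key, in_features=4, out_features=2):
    return {
        "w": jax.random.normal(key, (in_features, out_features)) * 0.1,
        "b": jnp.zeros((out_features,)),
    }

def predict_one(params, x):
    # x has shape (in_features,): a single example with no batch axis.
    return x @ params["w"] + params["b"]

# in_axes=(None, 0): broadcast the params, map over axis 0 of x.
predict_batch = jax.jit(jax.vmap(predict_one, in_axes=(None, 0)))

params = init_params(jax.random.PRNGKey(0))
out_small = predict_batch(params, jnp.ones((3, 4)))  # shape (3, 2)
out_large = predict_batch(params, jnp.ones((7, 4)))  # shape (7, 2)
```

A new batch size triggers a retrace and recompile under `jax.jit`, but no change to the model code.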
Side note: I also hit bugs with dropout; the rngs context would crash, deterministic or not. I also wanted to write a dynamic-length SGLD scan that stops early once the Hessian indicates an optimum, but apparently JAX can't do this :(
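On the dynamic-length loop: `jax.lax.while_loop` does support a data-dependent stopping condition under `jax.jit`, unlike `jax.lax.scan`, which has a fixed length. The sketch below is plain gradient descent with a gradient-norm stopping test. It is a stand-in, not SGLD: it adds no Langevin noise, and it does not use a Hessian criterion. The `energy` function and the constants are hypothetical.

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 0.1   # hypothetical step size
TOL = 1e-4        # stop once the gradient norm drops below this
MAX_STEPS = 1000  # hard cap so the loop always terminates

# Hypothetical energy with its minimum at x = 3.0 in every coordinate.
def energy(x):
    return jnp.sum((x - 3.0) ** 2)

grad_energy = jax.grad(energy)

@jax.jit
def descend_until_converged(x0):
    # Loop carry: (step counter, current point).
    def cond_fn(state):
        i, x = state
        # Continue while under the step budget and the gradient is still large.
        return jnp.logical_and(i < MAX_STEPS, jnp.linalg.norm(grad_energy(x)) > TOL)

    def body_fn(state):
        i, x = state
        return i + 1, x - STEP_SIZE * grad_energy(x)

    init = (jnp.array(0, dtype=jnp.int32), x0)
    return jax.lax.while_loop(cond_fn, body_fn, init)

steps, x_final = descend_until_converged(jnp.zeros(2))
```

One caveat: `lax.while_loop` is not reverse-mode differentiable, so this pattern fits sampling or inference loops better than loops you need to backpropagate through.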
Logs, error messages, etc:
with jax.vmap
with nnx.vmap
Steps to reproduce:
Note: to save time, you can skip the second code block; it's just boilerplate for loading MNIST. https://colab.research.google.com/drive/11gKSydAn3p_fLdQaqw7pkdHmC2zyW12k?usp=sharing
Hey! In the NNX Transforms section I think we make it clear that you must use the nnx.* version of the transform when using NNX objects. Is there something we should improve here?
Yes, you're right, the docs are clear about using the nnx transforms.
The issue I hit is this: I just want to map a function over each item in a batch. With the vmap transform, the first call to the transformed function receives a single slice of the batch, which is what I wanted. On the next call, though, it receives the whole batch at once, which makes it crash with a shape error in a call to "concatenate".
I reckon it's using an XLA-compiled version of the same function on the second invocation.
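A quick way to check which shape a function actually receives is to record it with a Python side effect. That side effect runs only while JAX traces the function, not on cached compiled calls. The plain-JAX sketch below (not nnx; `featurize_one` is hypothetical) shows that under `jax.jit(jax.vmap(...))` the per-example function sees the per-example shape on every trace, including the retrace triggered by a new batch size. A `concatenate` along the example's own axis therefore works.

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # shapes observed by the per-example function during tracing

def featurize_one(x):
    # x is a single example of shape (4,), so axis 0 is its feature axis.
    seen_shapes.append(x.shape)  # Python side effect: runs only at trace time
    return jnp.concatenate([x, x ** 2], axis=0)

batched = jax.jit(jax.vmap(featurize_one))

xs = jnp.arange(12.0).reshape(3, 4)
first = batched(xs)                # traces and compiles
second = batched(xs)               # reuses the compiled code, no retrace
third = batched(jnp.ones((5, 4)))  # new batch size: retraces, still per-example
```

Every recorded shape is `(4,)`. If the function does see the full `(batch, 4)` shape, it is worth checking whether it is being called somewhere outside the vmapped wrapper.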
As for what to improve: could the docs have a section dedicated to working with batches of data?
I was looking for more focused troubleshooting docs on shape bugs when handling batches with vmap, or advice on how to apply a function to each slice of a batch independently without rewriting the function to handle the whole batch at once.
(It's an EBM, an energy-based model, so I don't want to inappropriately mix energy gradients across members of the same batch, and it breaks if I try to apply the single-item function to the whole batch.)
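For the EBM concern, composing `jax.vmap` with `jax.grad` yields per-example gradients that each depend only on their own example. Below is a minimal plain-JAX sketch with a hypothetical quadratic energy E(w, x) = 0.5 * ||x - w||^2, chosen so that the gradients are easy to verify by hand: dE/dx = x - w and dE/dw = w - x.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example energy: E(w, x) = 0.5 * ||x - w||^2.
def energy_one(w, x):
    return 0.5 * jnp.sum((x - w) ** 2)

# Input gradients (argnums=1), e.g. for Langevin-style sampling:
# w is broadcast (None), xs is mapped over axis 0, one gradient row per example.
per_example_input_grads = jax.vmap(jax.grad(energy_one, argnums=1), in_axes=(None, 0))

# Parameter gradients (argnums=0), again one row per example, never summed.
per_example_param_grads = jax.vmap(jax.grad(energy_one, argnums=0), in_axes=(None, 0))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[0.0, 0.0], [2.0, 2.0], [1.0, -1.0]])

gx = per_example_input_grads(w, xs)  # row i equals xs[i] - w
gw = per_example_param_grads(w, xs)  # row i equals w - xs[i]
```

By contrast, calling `jax.grad` on a summed or averaged batch loss with respect to the parameters aggregates the per-example contributions into a single gradient.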
This might just be my general inexperience with JAX, and a vmap issue rather than an nnx issue; maybe I'm missing something about how vmap works.
I think currently you want this:
(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.vmap(
loss_value_and_grad_fn,
in_axes=(None, 0, 0),
state_axes={...: None}, # <<== add this
)
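As I understand it, `state_axes={...: None}` in the Flax 0.8.x `nnx.vmap` broadcasts the module state instead of mapping it over a batch axis. That is the role `None` plays in `in_axes` for plain `jax.vmap`. The sketch below is a plain-JAX analogue, not nnx. It uses a hypothetical params dict and a single-example loss whose aux output mimics the `(y_preds, ...)` tuple in the snippet above.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example loss returning (loss, aux).
def loss_one(params, x, y):
    y_pred = jnp.dot(x, params["w"]) + params["b"]
    loss = (y_pred - y) ** 2
    return loss, (y_pred,)

loss_value_and_grad_one = jax.value_and_grad(loss_one, has_aux=True)

# None: broadcast params to every example; 0: map over axis 0 of x and y.
batched = jax.vmap(loss_value_and_grad_one, in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 2.0, 3.0])

(loss_values, (y_preds,)), grads = batched(params, xs, ys)
# loss_values: (3,), y_preds: (3,), grads["w"]: (3, 2), grads["b"]: (3,)
```

Every output gains a leading batch axis, including the gradients, so each example's gradient stays separate until you decide to reduce it.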
But once #3963 lands, you will be able to do:
(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.experimental.vmap(
loss_value_and_grad_fn,
in_axes=(None, 0, 0), # <<== just works
)
and hopefully the experimental version becomes the norm.