Results: 2 comments of mchagneux
Definitely agree. I sometimes wrap `jax.vmap` with something like this:

```
def named_vmap(f, axes_names, **kwargs):
    in_axes = ({k:0 if k in axes_names else None for k in kwargs.keys()},)
    return jax.vmap(lambda...
```
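The snippet in the comment is cut off, so the following is not the commenter's code. It is a hypothetical sketch of one way a wrapper with that signature could work, assuming the truncated line packs the keyword arguments into a single dict. The packing matters because `jax.vmap` only applies `in_axes` to positional arguments (keyword arguments are always mapped over their leading axis), so the dict is passed positionally and `in_axes` is a one-element tuple holding a dict with `0` (mapped) or `None` (broadcast) per name. The `affine` function and the array shapes in the usage lines are made up for illustration.

```python
import jax
import jax.numpy as jnp


def named_vmap(f, axes_names, **kwargs):
    # Hypothetical completion: map axis 0 of every keyword argument whose
    # name is listed in axes_names, and broadcast (None) all the others.
    in_axes = ({k: 0 if k in axes_names else None for k in kwargs},)
    # Pack the keyword arguments into one positional dict so that
    # in_axes is actually honoured by jax.vmap.
    return jax.vmap(lambda kw: f(**kw), in_axes=in_axes)(kwargs)


# Illustrative usage: vectorize over "x" only, sharing "w" and "b".
def affine(x, w, b):
    return x @ w + b


xs = jnp.ones((4, 3))  # batch of 4 inputs of size 3
w = jnp.ones((3, 2))   # shared weights
b = jnp.zeros(2)       # shared bias
out = named_vmap(affine, ["x"], x=xs, w=w, b=b)
print(out.shape)  # (4, 2)
```

Selecting mapped arguments by name avoids having to keep a positional `in_axes` tuple in sync with the function signature; the trade-off is that every call site has to pass arguments by keyword.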
OK, very clear. Thank you for answering!