Michael O'Brien
Thanks so much for the timely response! This is very helpful. I use the approximation to analytically compute marginal posteriors via Gaussian integrals in some of my parameters (so in...
Sorry for the late reply. This is very helpful and makes sense, particularly with regard to how to think about the noisy loss surface.
Ah okay, thank you for the insight. I had been reading through this example of the advanced API, but I assumed there would be an issue with the fact that each...
Okay, this really clears things up! Thanks so much, this was puzzling me. I had not thought this through in the context of the general behavior of `jax.vmap(jax.lax.while_loop)`.
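For anyone reading along, here is a minimal sketch of the `jax.vmap(jax.lax.while_loop)` behavior mentioned above. The function `count_up_to` and the `limits` array are made up for illustration and are not from the original thread. When the loop condition depends on a batched value, the batched loop keeps iterating until every batch element's condition is false, and elements that have already finished keep their state unchanged, so each element still gets its own per-element result.

```python
import jax
import jax.numpy as jnp


def count_up_to(limit):
    """Count from 0 up to `limit` with a while loop (illustrative only)."""

    def cond(state):
        i, _ = state
        return i < limit

    def body(state):
        # `calls` tracks how many updates this batch element actually received.
        i, calls = state
        return i + 1, calls + 1

    init = (jnp.int32(0), jnp.int32(0))
    return jax.lax.while_loop(cond, body, init)


# Each batch element has a different stopping point.
limits = jnp.array([1, 5, 3], dtype=jnp.int32)

# Under vmap there is a single loop that runs while ANY element is unfinished
# (here, as many iterations as the largest limit); finished elements are masked.
final_i, calls = jax.vmap(count_up_to)(limits)
print(final_i, calls)
```

The practical consequence is cost rather than correctness: every element pays for as many iterations as the slowest element in the batch.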
For more context, I develop a JAX package for scientific applications (https://github.com/mjo22/cryojax/). In my discipline, batching is very common but there are many different ways downstream users may want to...
I don’t have a strong preference between generalizing `jax.vmap` and writing a new API! I can definitely see why the latter would be preferable and perhaps even more powerful. Taking...
@mattjj @carlosgmartin I opened a new issue for the more general points discussed here: #30528
Aha, the solution in #242 is something like what I had in mind. My only comment is that my modules can have non-arrays as leaves, so using `filter_vmap` under the hood...
FYI, I'll just note here that as of JAX 0.7.0, Python >=3.11 is required. This isn't related to the issue, but it is important to know.
The only way I can think of doing it is the following, which feels like a hack and would lead to a confusing user API: ```python class ForwardedModule(AbstractModule): some_field: int...