Dan Foreman-Mackey
Looks good - thanks!
@ChenAo-Phys — Thanks for the report! And thanks to @jaro-sevcik for suggesting this workaround.

> By the way, is there any plan to solve the sharding problem in the future?...
@ChenAo-Phys — The LU factorization on GPU should now (with JAX v0.6.0) automatically shard properly in cases like this. Want to try again in your environment if this is still...
Excellent! I'm going to close this now, but I will note that the Cholesky factorization still doesn't shard properly on GPU. All the other factorizations should work! It's on my...
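Since the discussion is about sharding the batched factorizations, here is a minimal sketch of the kind of batched solve that exercises the LU path; the mesh layout, axis name, and array sizes are illustrative and not taken from the thread.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a mesh over whatever devices are available and shard the
# batch axis of a stack of small linear systems across it.
ndev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("batch",))

rng = np.random.default_rng(0)
# Add a multiple of the identity so the matrices are well conditioned.
a = jnp.asarray(rng.normal(size=(2 * ndev, 4, 4))) + 4 * jnp.eye(4)
b = jnp.asarray(rng.normal(size=(2 * ndev, 4)))
a = jax.device_put(a, NamedSharding(mesh, P("batch", None, None)))
b = jax.device_put(b, NamedSharding(mesh, P("batch", None)))

# solve uses an LU factorization under the hood; per the comment above,
# this should shard across the batch axis on GPU as of JAX v0.6.0.
x = jnp.linalg.solve(a, b)
```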
@trevor-m — Thanks for your patience here! Can you rebase your PR onto the current `main` branch? We'll get this in ASAP after that. Thanks!
It looks like all the new tests need to be conditioned on the jaxlib version. The issue here is that most of the CI jobs run with the released version...
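A sketch of what conditioning a test on the jaxlib version can look like; the version tuple and threshold below are stand-ins for illustration, not the real values from this PR (in the JAX repo the installed version is available as a tuple of ints).

```python
import unittest

# Stand-in for the installed jaxlib version; illustrative only.
jaxlib_version = (0, 5, 3)

class NewFeatureTest(unittest.TestCase):
    def test_new_feature(self):
        # Skip when CI runs against a released jaxlib that predates
        # the feature under test (threshold is hypothetical).
        if jaxlib_version < (0, 6, 0):
            self.skipTest("not supported by the released jaxlib on CI")
        # ... the actual test body would go here ...

# Run the test case to show the gating in action.
suite = unittest.TestLoader().loadTestsFromTestCase(NewFeatureTest)
result = unittest.TestResult()
suite.run(result)
print(len(result.skipped))  # the one gated test was skipped
```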
It also looks like you need to add an explicit dependency on `:path` here: https://github.com/jax-ml/jax/blob/4b4fb9dae9eb7e2740d70de5b4a610f979530382/jax/BUILD#L425-L438
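For reference, a Bazel rule carrying such an explicit dep might look like the following sketch; the target name and sources are illustrative, and only `:path` comes from the linked `BUILD` file.

```python
py_library(
    name = "my_new_module",       # illustrative target name
    srcs = ["my_new_module.py"],  # illustrative source file
    deps = [
        ":path",  # the explicit dependency mentioned above
    ],
)
```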
@trevor-m — This should be good to go now! ~~but can you rebase onto the current main branch so that we can rerun the import?~~ Quite a few internal Google...
> That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side.

This actually _isn't_ the behavior of `vectorized`! I know that the...
Are you sure you want `assert z.shape == ()`? My suggestion was that you write:

```python
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  #
```
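For context, a runnable sketch of the nested-`vmap` pattern suggested above; the `broadcasting_vmap` decorator is elided in the comment, so this shows just the two plain `vmap`s: each one maps a single argument, producing one output batch axis per mapped input.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def f(x, y, z):
    # Inside the nested vmaps, x and y are scalars; the batch axes
    # are reintroduced in the output, one per mapped argument.
    return x * y + z

out = f(jnp.arange(3.0), jnp.arange(4.0), 1.0)
print(out.shape)  # one axis per mapped input: (3, 4)
```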