Jake Vanderplas
It seems to me the behavior of `expand_dims` with multiple axes could be specified without any of the ambiguities mentioned above. Basically, for `y = expand_dims(x, axes)` when `axes` is...
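The comment breaks off before it states the rule, so what follows only illustrates one unambiguous specification. It is the one NumPy's `np.expand_dims` uses for a tuple of axes. The output has `x.ndim + len(axes)` dimensions. Each axis marks a new size-1 dimension, and negative axes are normalized against the *output* rank. The input dimensions fill the remaining slots in order. `expand_dims_shape` is a hypothetical helper written for this sketch, not a JAX or NumPy API.

```python
import numpy as np

def expand_dims_shape(shape, axes):
    """Hypothetical helper: output shape of expand_dims(x, axes) under NumPy's rule."""
    out_ndim = len(shape) + len(axes)
    normalized = set()
    for a in axes:
        # Axes index into the *output*, so negative values are taken mod out_ndim.
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of range for output rank {out_ndim}")
        normalized.add(a % out_ndim)
    if len(normalized) != len(axes):
        raise ValueError("repeated axis")
    remaining = iter(shape)
    # New axes get size 1; the original dimensions fill the other slots in order.
    return tuple(1 if i in normalized else next(remaining) for i in range(out_ndim))

# Cross-check the rule against NumPy's own implementation.
x = np.zeros((3, 4))
for axes in [(0, 2), (-1, 0), (1, -2)]:
    assert expand_dims_shape(x.shape, axes) == np.expand_dims(x, axes).shape
```

Because every axis refers to a position in the result, the order in which `axes` is listed does not matter, and `(-1, 0)` means "first and last" no matter how many axes are inserted.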
It looks like you need something like this at the beginning of your function:
```python
if jax.dtypes.issubdtype(var.dtype, jax.dtypes.prng_key):
    impl = jax.random.key_impl(var)  # query the impl while var is still a typed key
    var = jax.random.key_data(var)
```
And if you need to...
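Below is a sketch of a full round trip: unwrap a typed key, work on the raw `uint32` data, then restore the key. It assumes a JAX version with typed keys (`jax.random.key`). It also assumes the restoring step uses `jax.random.wrap_key_data`. The comment above is cut off, so this is not necessarily what it went on to recommend. `split_key` and `rewrap_key` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp

def split_key(var):
    """Hypothetical helper: return (raw data, impl), with impl=None for non-key inputs."""
    if jax.dtypes.issubdtype(var.dtype, jax.dtypes.prng_key):
        impl = jax.random.key_impl(var)  # must be read before converting to raw data
        return jax.random.key_data(var), impl
    return var, None

def rewrap_key(data, impl):
    """Hypothetical helper: restore a typed key from raw data (assumed via wrap_key_data)."""
    if impl is None:
        return data
    return jax.random.wrap_key_data(data, impl=impl)

key = jax.random.key(0)
data, impl = split_key(key)        # data is a plain uint32 array
restored = rewrap_key(data, impl)  # typed key again, same implementation
```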
How hard would it be to add a shift-invert mode to the current lobpcg solver?
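Shift-invert is a spectral transformation. If `A v = lam v`, then `(A - sigma*I)^-1 v = v / (lam - sigma)`. The eigenvalues of `A` closest to the shift `sigma` therefore become the largest-magnitude eigenvalues of the transformed operator. The NumPy sketch below only checks that mapping, using a dense inverse on a small symmetric matrix. An iterative solver would normally apply `(A - sigma*I)^-1` through a linear solve at each step instead of forming it. How that would fit into JAX's existing `lobpcg` code is not addressed here.

```python
import numpy as np

def shift_invert_eigs(A, sigma):
    """Recover eigenvalues of symmetric A from those of (A - sigma*I)^-1."""
    n = A.shape[0]
    B = np.linalg.inv(A - sigma * np.eye(n))   # dense inverse: illustration only
    theta = np.linalg.eigvalsh((B + B.T) / 2)  # symmetrize to absorb round-off
    return sigma + 1.0 / theta, theta          # lam = sigma + 1/theta

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 6))
A = (M + M.T) / 2  # symmetric test matrix
sigma = 0.1

lam_recovered, theta = shift_invert_eigs(A, sigma)
lam_direct = np.linalg.eigvalsh(A)

# The largest-|theta| eigenvalue of the inverse maps to the eigenvalue of A nearest sigma.
nearest = sigma + 1.0 / theta[np.argmax(np.abs(theta))]
```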
Hi - thanks for the question! Your benchmarks may not be measuring what you think they are here, because you're not accounting for JIT compilation time or for asynchronous dispatch....
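Here is a minimal timing sketch that separates the two effects. JAX's ahead-of-time API (`jax.jit(f).lower(x).compile()`) lets compilation be measured on its own. `block_until_ready()` waits for the asynchronously dispatched result before the clock is read. The function `f` and the array size are arbitrary placeholders.

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary placeholder workload.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))

# 1. Trace and compile ahead of time, so compilation is timed on its own.
t0 = time.perf_counter()
compiled = jax.jit(f).lower(x).compile()
compile_time = time.perf_counter() - t0

# 2. Run the compiled executable. The call returns once work is dispatched,
#    so block_until_ready() is needed before reading the clock again.
t0 = time.perf_counter()
result = compiled(x)
dispatch_time = time.perf_counter() - t0
result.block_until_ready()
run_time = time.perf_counter() - t0
```

A single run is noisy. For steady-state numbers, repeat step 2 several times and average.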
Your updated script doesn't seem to isolate compilation time – even if you're not using `jax.jit` yourself, it is used internally by APIs you are calling.
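You can see this without any `jax.jit` in your own code by timing the same library call twice. In current JAX versions `jnp.linalg.solve` is itself wrapped in `jax.jit`. That is an implementation detail, treated here as an assumption. The first call pays for tracing and compilation, and the second reuses the cached executable. The exact numbers will vary by machine and backend.

```python
import time
import jax.numpy as jnp

def time_call(fn, *args):
    """Wall-clock time of one call, waiting for the async result to finish."""
    t0 = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - t0

a = 2.0 * jnp.eye(128)
b = jnp.ones(128)

cold = time_call(jnp.linalg.solve, a, b)  # first call: includes tracing + compilation
warm = time_call(jnp.linalg.solve, a, b)  # later calls: cached executable reused
solution = jnp.linalg.solve(a, b)         # (2*I) x = 1  =>  x = 0.5 everywhere
```

A benchmark should make one warm-up call like `cold` before it starts measuring.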
Let's stick with `ddof=0` as the default. Simplest is best here, I think.
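`ddof` ("delta degrees of freedom") sets the divisor to `N - ddof`. `ddof=0` gives the population variance and `ddof=1` gives the unbiased sample estimate. Here is a small check with `jnp.var` and `jnp.std`, which, like NumPy, default to `ddof=0`:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])  # mean 2.5, sum of squared deviations 5.0

var_pop = jnp.var(x)             # ddof=0: 5.0 / 4 = 1.25
var_sample = jnp.var(x, ddof=1)  # ddof=1: 5.0 / 3, about 1.6667
std_pop = jnp.std(x)             # sqrt(1.25)
```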
Thanks for the report. I was able to reduce the problem to this:
```python
import jax
jax.lax.random_gamma_grad(27708268.0, 27708266.0)
```
These values are being generated and passed to `gamma_grad` in your...
Investigating further, it looks like the non-convergence occurs whenever both inputs are above `2 ** 24`; 24 is (probably not coincidentally) the number of significant bits in the mantissa of...
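A quick NumPy check of the float32 facts involved. float32 stores 23 explicit mantissa bits plus an implicit leading bit, for 24 significant bits. Above `2 ** 24`, consecutive representable values are therefore at least 2 apart. The two inputs in the reduced example differ by exactly one such step. This only illustrates the precision boundary; it does not by itself explain why `random_gamma_grad` fails to converge there.

```python
import numpy as np

info = np.finfo(np.float32)
significant_bits = info.nmant + 1  # 23 stored bits + 1 implicit bit = 24

limit = np.float32(2 ** 24)  # 16777216.0
# At 2**24, adding 1 rounds back to the same float32 value.
lost_increment = (limit + np.float32(1)) == limit

a = np.float32(27708268.0)
x = np.float32(27708266.0)
gap = np.spacing(a)  # distance from a to the next representable float32
```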
Looks good – I suspect `std` will have the same issue. Fix that here as well?
This is probably too duplicative of `jax.tree_util.register_dataclass`. Closing.
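Here is a minimal sketch of the existing API referenced above. It assumes a JAX version that includes `jax.tree_util.register_dataclass`, and `Params` is a hypothetical example class. Fields listed in `data_fields` become pytree leaves. Fields in `meta_fields` are treated as static auxiliary data and must be hashable.

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass
class Params:
    w: jax.Array
    b: jax.Array
    name: str  # static metadata: part of the treedef, not a leaf

jax.tree_util.register_dataclass(Params, data_fields=["w", "b"], meta_fields=["name"])

p = Params(w=jnp.ones((2, 2)), b=jnp.zeros(2), name="layer")

leaves = jax.tree_util.tree_leaves(p)                 # [w, b]
doubled = jax.tree_util.tree_map(lambda a: 2 * a, p)  # name is carried through unchanged

@jax.jit
def total(params):
    # Works because Params is now a registered pytree.
    return params.w.sum() + params.b.sum()
```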