Jake Vanderplas
When I try benchmarking your original function using `jax.jit`, I find that JAX is 4x faster than autograd on both CPU and GPU for inputs of size 1000:
```python
import ...
```
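For context, a minimal sketch of the kind of benchmark I mean; the function `f`, the input, and the timing loop below are stand-ins for illustration, not the code from the original report:
```python
import time

import numpy as np
import jax
import jax.numpy as jnp

def f(x):
  # Hypothetical stand-in for the function being benchmarked.
  return jnp.sum(jnp.tanh(x) ** 2)

f_jit = jax.jit(f)
x = jnp.array(np.random.randn(1000))

f_jit(x).block_until_ready()  # warm-up: the first call includes compilation

start = time.perf_counter()
for _ in range(100):
  f_jit(x).block_until_ready()  # block so we measure compute, not async dispatch
print(f"mean time per call: {(time.perf_counter() - start) / 100 * 1e6:.1f} µs")
```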
It looks like these functions can't be imported directly, because `jax.nn.softmax` defaults to `axis=-1`, while `scipy.special.softmax` defaults to `axis=None`. We'll have to create wrappers for these in `jax/_src/scipy/special.py`.
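A rough sketch of what such wrappers might look like; the exact signatures, docstrings, and placement would follow the rest of `jax.scipy.special`, so this is illustrative rather than the final implementation:
```python
import jax.numpy as jnp

def softmax(x, axis=None):
  # scipy.special.softmax defaults to axis=None (normalize over all elements),
  # so we can't simply re-export jax.nn.softmax, which defaults to axis=-1.
  x_max = jnp.max(x, axis=axis, keepdims=True)  # subtract max for numerical stability
  unnormalized = jnp.exp(x - x_max)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

def log_softmax(x, axis=None):
  # Same idea for scipy.special.log_softmax.
  shifted = x - jnp.max(x, axis=axis, keepdims=True)
  return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True))
```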
> We'll have to create wrappers for these in `jax/_src/scipy/special.py`. Also, we might want to make the `jax.scipy.special` wrappers use the non-deprecated softmax version by default. The deprecated version has...
Thanks for the report. It looks like the compiler is able to recognize the constant expression in one case, but not in the other. I don't think I'd consider this...
We can see what the compiler is doing with these functions using [ahead of time lowering](https://jax.readthedocs.io/en/latest/aot.html). For example:
```python
h1_lowered = jax.jit(h1).lower(x, a, b, c).compile()
h2_lowered = jax.jit(h2).lower(x, a, b, c).compile()
```
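As a concrete illustration of that workflow (the function and inputs below are made up; `h1` is just a stand-in for the functions in the report):
```python
import numpy as np
import jax
import jax.numpy as jnp

def h1(x, a, b, c):
  # Stand-in for one of the functions being compared in the report.
  return jnp.einsum('ij,jk->ik', x * a, b) + c

x = jnp.asarray(np.random.randn(8, 8), dtype=jnp.float32)
a, b, c = 2.0, jnp.eye(8), 1.0

lowered = jax.jit(h1).lower(x, a, b, c)
print(lowered.as_text())    # IR as emitted by JAX, before XLA optimization
compiled = lowered.compile()
print(compiled.as_text())   # optimized HLO that will actually run
```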
I think I understand the difference: the expensive allocation is the output of the `einsum`. In `h1`, the input to the einsum is an internal buffer (the output of `x...
Hi - thanks for the question! Could you take a look at https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code and update your benchmarks? In particular, accounting for asynchronous dispatch via `block_until_ready()` and separating out compile time and...
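Concretely, the pattern the FAQ suggests looks something like this (the function here is a made-up example, not your benchmark):
```python
import time

import jax
import jax.numpy as jnp

def f(x):
  # Made-up example function.
  return (x @ x.T).sum()

f_jit = jax.jit(f)
x = jnp.ones((1000, 1000))

# 1. Time compilation separately; otherwise the first call includes it.
t0 = time.perf_counter()
f_compiled = f_jit.lower(x).compile()
print("compile time:", time.perf_counter() - t0)

# 2. Time execution, blocking on the result so asynchronous dispatch
#    doesn't make the computation look instantaneous.
t0 = time.perf_counter()
f_compiled(x).block_until_ready()
print("run time:", time.perf_counter() - t0)
```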
Thanks! A couple things: 1. Since you're running on CPU, you might also try on GPU. The XLA GPU compiler is a bit more mature than the CPU compiler, so...
If the problem is slow convergence, I would check the learning rate. `0.0000005` seems very small.
Without digging further, my guess would be that you've found a local optimum.