Jake Vanderplas
Thanks @ykirpichev – I think this is a pretty good start! Could you squash all the changes into a single commit, and then we'll start the process of pulling it...
Hi - thanks for the question! I think this is behaving as expected. When it comes to autodiff, sparse matrices are fundamentally different from dense matrices in that *zero elements...
It seems like this would best fit in an advanced vmap guide, probably related to the work in #18585 (which I've neglected for far too long...)
Hi - thanks for the report! The `Rotation` functionality has some implementation issues, and is a part of the package that we've identified (retroactively) as out-of-scope for JAX (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-spatial),...
> It is funny that you mention the array API. From what I can understand from the scipy issue ([scipy/scipy#18286](https://github.com/scipy/scipy/issues/18286)) on the matter, they are hoping to "dispatch" this kind...
Interesting question! I suspect the reason for the performance difference here is that the GPU hardware is designed and tuned for float32 computation, and not for float16 computation. It would...
No, I don't think such conversions are happening – you can see exactly what operations the compiler is emitting using [ahead of time lowering](https://jax.readthedocs.io/en/latest/aot.html) to output the compiled HLO. This...
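As a rough illustration of the ahead-of-time lowering mentioned above, here is a minimal sketch. The `scatter_add` function and its inputs are hypothetical stand-ins, not the code from the original report; the idea is to print the compiler's output and look for any dtype conversions (`convert` ops) around the scatter.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the operation being investigated.
def scatter_add(x, idx, updates):
    return x.at[idx].add(updates)

x = jnp.zeros(8, dtype=jnp.float16)
idx = jnp.array([0, 2, 2, 5])
updates = jnp.ones(4, dtype=jnp.float16)

# Lower for these specific argument shapes and dtypes.
lowered = jax.jit(scatter_add).lower(x, idx, updates)
stablehlo_text = lowered.as_text()   # IR before backend optimizations

# Compile for the current backend and dump the optimized HLO.
compiled = lowered.compile()
hlo_text = compiled.as_text()
print(hlo_text)

# The compiled object can be called directly with matching inputs.
out = compiled(x, idx, updates)
```

The printed HLO depends on the backend and JAX version, so the exact text will differ between machines.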
My best guess is still that the hardware you're using is not optimized for the kinds of operations you're performing (i.e. scatters) in float16, and is more optimized for float32....
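One way to check this guess on a given device is to time the same scatter in both dtypes. The sketch below is only an assumption about what such a benchmark could look like: `time_scatter`, the array sizes, and the repeat count are all made up for illustration, and the results will vary by hardware.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def scatter_add(x, idx, updates):
    return x.at[idx].add(updates)

def time_scatter(dtype, n=100_000, m=1_000, repeats=10):
    """Average wall-clock seconds per scatter-add call for the given dtype."""
    idx = jax.random.randint(jax.random.PRNGKey(0), (n,), 0, m)
    x = jnp.zeros(m, dtype=dtype)
    updates = jnp.ones(n, dtype=dtype)
    # Warm-up call so compilation time is excluded from the measurement.
    scatter_add(x, idx, updates).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready accounts for JAX's asynchronous dispatch.
        scatter_add(x, idx, updates).block_until_ready()
    return (time.perf_counter() - start) / repeats

timings = {
    "float16": time_scatter(jnp.float16),
    "float32": time_scatter(jnp.float32),
}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```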
I think these are not equivalent operations – wouldn't `torch.scatter` be equivalent to JAX scatter, not JAX segment sum?
Ah, thanks for the clarification. Looks like it is doing the same thing – I'm not sure why JAX's version is slower.
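For reference, a small sketch of how a segment sum relates to a scatter-add in JAX. The data here is invented for illustration and is not the benchmark from the original issue.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs (not from the original report).
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 2, 2])
num_segments = 3

# Segment sum: add up the entries of `data` that share a segment id.
via_segment_sum = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)

# The same reduction written as an explicit scatter-add into zeros.
via_scatter_add = jnp.zeros(num_segments, dtype=data.dtype).at[segment_ids].add(data)

print(via_segment_sum)
print(via_scatter_add)
```

Both expressions accumulate values at the positions given by `segment_ids`, so for inputs like these they produce the same result.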