Awni Hannun
Let's do these as two issues as `diag` is much easier than multivariate normal. I assume for multivariate normal you need a non-diagonal covariance?
Leaving this issue for `mx.random.multivariate_normal` and created #503 for `diag`.
FYI: for multivariate normal we probably need matrix inversion, e.g. `mx.linalg.inv`, which will also likely help with other things.
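For context, a minimal sketch of how multivariate normal sampling is usually done, written in NumPy since `mx.random.multivariate_normal` doesn't exist yet: sampling itself only needs a factorization of the covariance (e.g. Cholesky), while inversion comes in for things like the log-density. The array names here are illustrative, not MLX API.

```python
import numpy as np

rng = np.random.default_rng(0)
mean = np.array([0.0, 1.0])
cov = np.array([[2.0, 0.5],
                [0.5, 1.0]])

# Factor the covariance: cov = L @ L.T
L = np.linalg.cholesky(cov)

# Transform standard normal draws: sample = mean + L @ z
z = rng.standard_normal((10_000, 2))
samples = mean + z @ L.T

# The sample covariance should approximate cov
print(np.cov(samples, rowvar=False))
```

The same trick is what a Metal-friendly implementation could lean on, since Cholesky avoids an explicit inverse.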
Cool package by the way! You should add a little quick start/usage guide (when it's ready for it).
We have a PR out for QR #310. I think SVD and Cholesky would go similarly. The main issue is there are no Metal implementations for most of LAPACK so...
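For reference, the CPU path these ops typically fall back to is just a LAPACK-backed routine like NumPy's; a quick check of what QR gives you (this is illustrative NumPy, not the MLX API from the PR):

```python
import numpy as np

A = np.random.default_rng(1).standard_normal((4, 3))

# Reduced QR: Q is 4x3 with orthonormal columns, R is 3x3 upper triangular
Q, R = np.linalg.qr(A)

print(np.allclose(Q @ R, A))          # reconstruction holds
print(np.allclose(Q.T @ Q, np.eye(3)))  # columns are orthonormal
```

The Metal gap is exactly that there's no GPU equivalent of this routine to call, so the GPU path has to be written from scratch or routed through the CPU.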
That's awesome!! Out of curiosity, could you tell me a bit more about (some) intended uses for the package? I would love to point people to it if you are...
This was closed a while ago.
The gradient of something like an `argmax` would be zeros almost everywhere. I don't think that's really the point of this discussion (which is very nice), but probably we should...
This is what I see Jax did FWIW:

```python
import jax
import jax.numpy as jnp

def fun(x):
    return jnp.argmax(x)

x = jnp.array([1.0, 2.0, 3.0])
out, vjf = jax.vjp(fun, x)
print(vjf(jnp.array(1)))
```
...
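The "zeros almost everywhere" claim is easy to check numerically without any autodiff framework: under a small perturbation the argmax doesn't move, so every finite-difference partial is exactly zero (away from ties). A plain NumPy sketch:

```python
import numpy as np

def argmax_float(x):
    # Treat the integer argmax as a real-valued output so we can
    # take finite differences of it.
    return float(np.argmax(x))

x = np.array([1.0, 2.0, 3.0])
eps = 1e-4

# Central finite differences: the winning index is unchanged by
# +/- eps perturbations, so every partial derivative is zero.
grad = np.array([
    (argmax_float(x + eps * e) - argmax_float(x - eps * e)) / (2 * eps)
    for e in np.eye(3)
])
print(grad)  # -> [0. 0. 0.]
```

Which matches the all-zeros cotangent JAX returns above; the only question is whether to return those zeros or raise.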
While I tend to agree with you, I also find the inconsistency with other zero-grad ops (e.g. `a > b`) a bit incongruous. Maybe a good compromise is to default...