
[Feature] Add support for mx.random.multivariate_normal()

tedwards2412 opened this issue 1 year ago • 10 comments

@NNSSA and I are working on a sampling package for mlx (https://github.com/tedwards2412/samplex), and it would be extremely useful to have these two functions (diag and multivariate_normal) for more generic sampling. The latter will require adding more functionality to the core.linalg sub-package. Is this likely to come in a future update? Happy to help if needed!

tedwards2412, Jan 19 '24 17:01

Let's split these into two issues, since diag is much easier than multivariate normal. I assume for multivariate normal you need a non-diagonal covariance?
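The non-diagonal question matters because a diagonal covariance needs no linear algebra at all: the components are independent, so a sample is just the mean plus an elementwise-scaled standard normal. A minimal sketch, using NumPy as a stand-in for mlx; sample_diag_normal is a hypothetical helper, not an mlx or NumPy API:

```python
import numpy as np

def sample_diag_normal(mean, variances, num_samples, rng):
    """Draw samples from N(mean, diag(variances)) without any matrix factorization.

    Hypothetical helper for illustration only; not part of mlx or NumPy.
    """
    mean = np.asarray(mean, dtype=float)
    std = np.sqrt(np.asarray(variances, dtype=float))      # per-dimension standard deviation
    z = rng.standard_normal((num_samples, mean.shape[0]))  # i.i.d. N(0, 1) draws
    return mean + z * std                                   # broadcast: scale each column, then shift

rng = np.random.default_rng(0)
samples = sample_diag_normal([1.0, -2.0], [4.0, 0.25], 200_000, rng)
print(samples.mean(axis=0))  # close to [1, -2]
print(samples.var(axis=0))   # close to [4, 0.25]
```

With a full (non-diagonal) covariance the columns are correlated, so the elementwise scale has to become a matrix factor, which is where the linalg work comes in.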

awni, Jan 19 '24 19:01

I'm leaving this issue for mx.random.multivariate_normal and created #503 for diag.

awni, Jan 19 '24 19:01

FYI: for multivariate normal we probably 🤔 need matrix inversion, e.g. mx.linalg.inv, which will probably help with other things too.
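For reference, these are the standard identities involved (not a statement about how mlx implements them). Drawing a sample only needs some factor $A$ with $AA^\top = \Sigma$, such as a Cholesky factor $L$. The inverse $\Sigma^{-1}$ appears when evaluating the log-density, and in practice it is usually applied through a triangular solve with $L$ rather than formed explicitly:

$$
x = \mu + A z, \qquad z \sim \mathcal{N}(0, I_d), \qquad AA^\top = \Sigma
\;\Longrightarrow\; \operatorname{Cov}(x) = A\,I_d\,A^\top = \Sigma
$$

$$
\log p(x) = -\tfrac{1}{2}(x-\mu)^\top \Sigma^{-1}(x-\mu) - \tfrac{1}{2}\log\det\Sigma - \tfrac{d}{2}\log 2\pi
$$

$$
\Sigma = LL^\top:\quad (x-\mu)^\top \Sigma^{-1}(x-\mu) = \lVert L^{-1}(x-\mu)\rVert_2^2, \qquad \log\det\Sigma = 2\sum_{i=1}^{d}\log L_{ii}
$$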

awni, Jan 19 '24 20:01

Great, thanks! And yes, non-diagonal covariance would be essential for this.

tedwards2412, Jan 19 '24 20:01

Cool package by the way! You should add a little quick start/usage guide (when it's ready for it).

awni, Jan 19 '24 20:01

Thanks! Will definitely do :-) Relatedly, I think NumPy uses a singular value decomposition (SVD) to sample from a multivariate normal.
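NumPy's Generator.multivariate_normal does default to method="svd" (it also accepts "eigh" and "cholesky"). The SVD idea can be sketched as follows, assuming the covariance is symmetric positive semi-definite: factor cov = U diag(s) Vh, then scale standard normals by diag(sqrt(s)) @ Vh. This is a hypothetical helper (sample_mvn_svd), not NumPy's exact source and not an mlx API:

```python
import numpy as np

def sample_mvn_svd(mean, cov, num_samples, rng):
    """Sample N(mean, cov) using an SVD of the covariance.

    Hypothetical helper mirroring the idea behind NumPy's default
    method="svd"; assumes cov is symmetric positive semi-definite.
    """
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    _, s, vh = np.linalg.svd(cov)        # cov = U @ diag(s) @ vh
    factor = np.sqrt(s)[:, None] * vh    # factor.T @ factor == cov for symmetric PSD cov
    z = rng.standard_normal((num_samples, mean.shape[0]))
    return z @ factor + mean             # each row then has covariance factor.T @ factor

rng = np.random.default_rng(0)
cov = np.array([[2.0, 0.8], [0.8, 1.0]])
samples = sample_mvn_svd([0.0, 3.0], cov, 200_000, rng)
print(np.cov(samples, rowvar=False))  # close to cov
```

The SVD tolerates singular or nearly singular covariances, which is one reason it is a safe default; NumPy's documentation notes that "cholesky" is faster but less robust.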

NNSSA, Jan 19 '24 21:01

We have a PR out for QR (#310). I think SVD and Cholesky would go similarly. The main issue is that there are no Metal implementations for most of LAPACK, so a lot of this will be CPU-only until we can get some kernels implemented.
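Once a Cholesky factorization is available, both sampling and density evaluation reduce to a matrix multiply and a linear solve against the factor, with no explicit inverse. A minimal sketch with NumPy standing in for mlx; the function names are hypothetical, and cov is assumed strictly positive definite (np.linalg.cholesky raises otherwise):

```python
import numpy as np

def sample_mvn_cholesky(mean, cov, num_samples, rng):
    """Sample N(mean, cov) via cov = L @ L.T (cov must be positive definite)."""
    mean = np.asarray(mean, dtype=float)
    L = np.linalg.cholesky(np.asarray(cov, dtype=float))  # lower-triangular factor
    z = rng.standard_normal((num_samples, mean.shape[0]))
    return mean + z @ L.T  # row form of x = mean + L @ z

def mvn_logpdf_cholesky(x, mean, cov):
    """Log-density of N(mean, cov) at each row of x, without forming inv(cov)."""
    x = np.atleast_2d(np.asarray(x, dtype=float))
    mean = np.asarray(mean, dtype=float)
    L = np.linalg.cholesky(np.asarray(cov, dtype=float))
    d = mean.shape[0]
    # Solve L @ w = (x - mean).T; ||w||^2 is the Mahalanobis term.
    # A general solve is used for simplicity; a triangular solve would exploit L's structure.
    w = np.linalg.solve(L, (x - mean).T)
    maha = np.sum(w**2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (maha + log_det + d * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
cov = np.array([[2.0, 0.8], [0.8, 1.0]])
samples = sample_mvn_cholesky([0.0, 3.0], cov, 200_000, rng)
print(np.cov(samples, rowvar=False))  # close to cov
print(mvn_logpdf_cholesky([[1.0, 2.0]], [0.0, 3.0], cov))
```

Cholesky is cheaper than SVD, which is attractive inside an MCMC loop that evaluates the density many times, at the cost of failing on singular covariances.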

awni, Jan 19 '24 21:01

@awni We finally got round to adding the quick start you recommended to https://github.com/tedwards2412/samplex. The sampling seems to work well with mlx so far. Looking forward to working on this more in the future!

tedwards2412, Feb 02 '24 22:02

That's awesome!! Out of curiosity, could you tell me a bit more about (some) intended uses for the package? I would love to point people to it, if you are OK with that, once I understand a bit better which use cases you are targeting.

awni, Feb 02 '24 22:02

Overall, the goal is to see if we can allow people to quickly run fairly large-scale MCMC sampling locally rather than having to run on a cluster. But it's also just a research project for us, to see whether particular sampling algorithms become substantially better when you can switch between running on CPU and GPU without any overhead; I don't think this has been explored at all before.

tedwards2412, Feb 05 '24 14:02

This was closed a while ago.

awni, Aug 10 '24 13:08