Add utilities to distributions

Open noahfarr opened this issue 1 year ago • 3 comments

In PyTorch, the following is easy:

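```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4)  # unnormalized scores for 4 categories
dist = Categorical(logits=logits)
value = torch.tensor(2)  # index of an observed or sampled category
log_prob = dist.log_prob(value)
entropy = dist.entropy()
```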

However, to do the same in MLX I have to compute the log-probability and entropy by hand. Would it be possible to add support for these methods? It would make working with distributions in MLX much more convenient (at least for me).
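For reference, the manual computation mentioned above can be sketched in a few lines. This is a minimal sketch, assuming `mlx.core`; the helper names `log_softmax`, `categorical_log_prob` and `categorical_entropy` are hypothetical and not part of MLX. It falls back to NumPy only so that it runs where MLX is not installed, since every op it uses has the same name in both libraries.

```python
# Hypothetical helpers (not MLX API): Categorical log_prob and entropy from logits.
# Uses mlx.core if available; NumPy fallback only so the sketch runs without MLX.
try:
    import mlx.core as xp
except ImportError:
    import numpy as xp


def log_softmax(logits, axis=-1):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = xp.max(logits, axis=axis, keepdims=True)
    shifted = logits - m
    return shifted - xp.log(xp.sum(xp.exp(shifted), axis=axis, keepdims=True))


def categorical_log_prob(logits, value):
    # log p(value) for integer category indices, batched over leading axes.
    logp = log_softmax(logits)
    idx = xp.expand_dims(value, -1)
    return xp.take_along_axis(logp, idx, axis=-1).squeeze(-1)


def categorical_entropy(logits):
    # H = -sum_k p_k * log p_k, where p = softmax(logits).
    logp = log_softmax(logits)
    return -xp.sum(xp.exp(logp) * logp, axis=-1)


logits = xp.array([1.0, 2.0, 3.0])
log_prob = categorical_log_prob(logits, xp.array(2))
entropy = categorical_entropy(logits)
```

Working in log space via `log_softmax` avoids computing `log(softmax(...))` directly, which can underflow to `log(0)` for very negative logits.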

noahfarr, Nov 12 '24 17:11

We don't have a distributions package the way PyTorch does, and we don't currently plan to add one to MLX core.

We can keep this issue open to gauge priority, but for now it's not likely something we will implement in the near future. A good short-term strategy would be to encourage third-party development of an mlx-distributions package or something similar.
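As an illustration of what such a third-party package could offer, here is a hypothetical sketch of a PyTorch-style `Normal` class with `log_prob` and `entropy` methods. The class and the mlx-distributions name are assumptions; nothing here is part of MLX. It assumes `mlx.core` and falls back to NumPy only so that it runs where MLX is not installed.

```python
# Hypothetical PyTorch-style distribution object (not part of MLX).
# Uses mlx.core if available; NumPy fallback only so the sketch runs without MLX.
import math

try:
    import mlx.core as xp
except ImportError:
    import numpy as xp


class Normal:
    """Univariate Gaussian with mean `loc` and standard deviation `scale` (> 0)."""

    def __init__(self, loc, scale):
        self.loc = xp.array(loc)
        self.scale = xp.array(scale)

    def log_prob(self, value):
        # log N(x | mu, sigma) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - 0.5 log(2 pi)
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - xp.log(self.scale) - 0.5 * math.log(2 * math.pi)

    def entropy(self):
        # H = 0.5 + 0.5 log(2 pi) + log(sigma); the result has the shape of `scale`.
        return 0.5 + 0.5 * math.log(2 * math.pi) + xp.log(self.scale)


dist = Normal(xp.array([0.0, 1.0]), xp.array([1.0, 2.0]))
log_prob = dist.log_prob(xp.array([0.0, 3.0]))
entropy = dist.entropy()
```

Storing the parameters as arrays lets one object represent a batch of distributions through broadcasting, which is much of the convenience `torch.distributions` provides.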

awni, Nov 12 '24 21:11

Thanks for your quick response :)

noahfarr, Nov 13 '24 09:11

To be honest, `torch.distributions` is very important, and even JAX has something similar. Building distributions by hand is quite painful for bioinformatics development.

c0nleyinnnn, Mar 24 '25 14:03