
[Question] Differences between MLX and JAX


This is somewhat related to https://github.com/ml-explore/mlx/issues/12, although I can see how mlx improves significantly upon torch. My question is: why reinvent the wheel with mlx when its core seems to so closely follow jax?

The mlx documentation lists these differences from torch/jax:

The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire. A notable difference between these frameworks and MLX is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU.
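
For concreteness, here is a minimal sketch of the behavior those docs describe, using MLX's per-operation stream argument (the shapes are arbitrary):

```python
import mlx.core as mx

a = mx.random.normal((256, 256))
b = mx.random.normal((256, 256))

# The same arrays are visible to every device; no .to(device) copies.
# The target device is chosen per operation via the stream argument.
c = mx.add(a, b, stream=mx.cpu)      # runs on the CPU
d = mx.matmul(a, b, stream=mx.gpu)   # runs on the GPU, reading the same memory
```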

Is there some part of jax that prevents utilizing shared memory? Unlike torch, you are generally not shuffling data around with .to(device) in jax. Couldn't Apple write an OpenXLA backend that avoids copies and executes operations on Neural Engine/GPU/CPU devices as necessary?
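
For contrast, a small sketch of how placement currently works in JAX, where an array is committed to a single device and moving it between devices is an explicit copy:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

x = jnp.ones((1024, 1024))      # committed to the default device (e.g. a GPU)
x_cpu = jax.device_put(x, cpu)  # an explicit cross-device copy, not shared memory
```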

The other major difference between jax and mlx seems to be the lazy evaluation. Lazy evaluation is not possible in jax, but doesn't jax.jit solve a similar problem to lazy evaluation? If you compile the outermost function, then unused variables, unnecessary function calls, etc. are compiled away. Are there situations where one would prefer lazy evaluation over compilation?
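
A rough illustration of the two mechanisms (assuming the current MLX and JAX APIs; values are arbitrary): MLX builds a graph lazily and evaluates only what a result actually depends on, while jax.jit traces the function once and lets XLA eliminate dead code at compile time:

```python
import mlx.core as mx
import jax
import jax.numpy as jnp

# MLX: lazy evaluation. Nothing runs until a value is needed.
x = mx.ones((3,))
wasted = mx.exp(x)   # only recorded in the graph; no kernel is launched
y = x + 1.0
mx.eval(y)           # computes just the subgraph that y depends on

# JAX: tracing + compilation. Dead code is removed by XLA instead.
@jax.jit
def f(z):
    unused = jnp.exp(z)  # traced, then eliminated as dead code
    return z + 1.0

f(jnp.ones((3,)))  # first call traces and compiles; later calls reuse the binary
```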

It seems a shame to me that the mlx and jax interfaces are so similar yet separate, because I can imagine how nice it would be to prototype a model on my MacBook and then deploy to a CUDA cluster for larger-scale training. I imagine this would also give researchers a smoother path away from CUDA as Apple's ML hardware improves.

smorad avatar Oct 12 '24 13:10 smorad

Long-standing issues like https://github.com/jax-ml/jax/issues/16321 are also one of the differences.

Downchuck avatar Oct 16 '24 22:10 Downchuck

> Long-standing issues like jax-ml/jax#16321 are also one of the differences.

Seems to be the same, actually: Cholesky on MLX is also only supported on the CPU.
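
A quick sketch of the limitation (matrix values arbitrary); at the time of this comment, the op has to be placed on the CPU stream explicitly:

```python
import mlx.core as mx

a = mx.array([[4.0, 1.0], [1.0, 3.0]])

# Cholesky is CPU-only in MLX for now, so pin the op to the CPU stream.
l = mx.linalg.cholesky(a, stream=mx.cpu)
```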

vhaasteren avatar Apr 02 '25 10:04 vhaasteren

It is hard to enumerate all the differences between two frameworks; it would be much easier to list what they have in common. But I think this issue is really about why Apple did not invest in JAX.

As someone who was not involved in the original decision, I'm very happy that MLX exists: it is much easier to understand how machine learning works by reading MLX's code. The source code of JAX and XLA, on the other hand, is heavily layered, carries a lot of historical baggage, and takes significantly more effort to read.

On the engineering side, it is feasible for a newcomer to add a new backend to MLX in a reasonable amount of time (as the CUDA backend shows), but adding a new backend to JAX would be a much larger challenge.

So I think the answer in https://github.com/ml-explore/mlx/issues/12#issuecomment-1843956313 still applies to this question, and I'm closing this issue.

zcbenz avatar Nov 13 '25 23:11 zcbenz