scvi-tools

Error when trying to add native MLX-accelerated scVI

Open · c0nleyinnnn opened this issue 9 months ago · 4 comments

Description:

Before the recent update to PyTorch's MPS support, using PyTorch MPS to accelerate scVI would result in NaN values in the returned matrix. For more details, see: Error when training model on M3 Max MPS.

I attempted to port scVI to the MLX framework by following the simple-scvi guide. I rewrote _mlxvae.py, _mlxscvi.py, and _mlxmixin.py so that they call the MLX backend for Metal acceleration. I chose MLX because of findings that it drives Metal GPU computation more efficiently than PyTorch MPS. For more information, see: phi2-llm-on-MLX-vs-Pytorch-MPS.

With the help of various development tools and AI, the code now runs in Python using MLX. However, the returned latent matrix exhibits a similar issue to the previous PyTorch MPS problem:

latent
array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]])

Due to my limited mathematical and programming skills, I am unable to resolve this issue or even identify where to start. A previous issue mentioned problems with lgamma, but I couldn't find the corresponding part in my code.

Additionally, I attempted this port because the batch-effect removal results from scVI accelerated with PyTorch MPS differ from those obtained with CUDA acceleration on my Windows host. On the same test data, the current MLX port runs in almost the same time as PyTorch MPS but is significantly faster than the jaxscvi function. That it is no faster than PyTorch MPS may be due to my lack of development expertise, as I cannot ensure proper compilation or efficient program logic.

If anyone can provide insights, assistance, or even take over this project, that would be fantastic. I am willing to contribute the existing code for free, and I hope someone can help develop more efficient and stable single-cell omics analysis tools for Apple Silicon.

Additional Context:

  • The issue with lgamma was mentioned in a previous discussion, but I couldn't locate it in my code.
  • The performance of the MLX port is comparable to PyTorch MPS but faster than jaxscvi, though this might be due to suboptimal development practices on my part.

Any help or collaboration would be greatly appreciated!

Current work: https://github.com/c0nleyinnnn/mlxSCVI

c0nleyinnnn · Mar 23 '25 06:03

I really like this idea and look forward to seeing how it goes. It will most probably be faster than Metal PyTorch. It’s a major undertaking to do this port, though, and I’m not familiar with MLX. The lgamma problem is a PyTorch issue: we use broadcasting for the reconstruction loss, and broadcasting is currently not handled correctly by Metal PyTorch. Do you get NaN values at the first step, or how long does it take before they appear?
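To make the lgamma reference concrete, here is a minimal pure-Python sketch of a negative binomial log-likelihood in the mean (`mu`) / inverse-dispersion (`theta`) parameterization, the kind of reconstruction term in which lgamma appears. It assumes one `theta` per gene that is reused for every cell, which is the broadcasting pattern described above. The function names are hypothetical and this is not the scvi-tools source; in an MLX or PyTorch port each `math` call would become an element-wise array operation.

```python
import math


def log_nb(x, mu, theta, eps=1e-8):
    """Log-probability of count x under a negative binomial with mean mu and
    inverse dispersion theta (both assumed > 0)."""
    log_theta_mu = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu)
        + x * (math.log(mu + eps) - log_theta_mu)
        + math.lgamma(x + theta)   # every lgamma argument must stay positive
        - math.lgamma(theta)
        - math.lgamma(x + 1.0)
    )


def reconstruction_loss(counts, means, theta_per_gene):
    """Negative log-likelihood per cell.

    counts and means are cells x genes nested lists; theta_per_gene holds one
    value per gene and is reused for every cell (the "broadcast").
    """
    losses = []
    for count_row, mean_row in zip(counts, means):
        log_lik = sum(
            log_nb(x, mu, theta)
            for x, mu, theta in zip(count_row, mean_row, theta_per_gene)
        )
        losses.append(-log_lik)
    return losses


# Example: 2 cells x 3 genes with made-up values
counts = [[0, 3, 1], [5, 0, 2]]
means = [[1.0, 2.5, 0.8], [4.0, 0.5, 1.5]]
theta = [1.0, 2.0, 0.5]
print(reconstruction_loss(counts, means, theta))
```

If `theta` or `mu` drifts to zero or below during training, these terms become invalid; Python's `math` raises an error in that case, whereas array libraries typically return inf or NaN, which is one way NaNs can enter the loss.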

canergen · Mar 23 '25 07:03

Yes, the problem and the workload turned out to be much harder than I expected. I mainly followed the jaxscvi implementation, but in practice NaN values appeared in the very first round of training. In the end I added some defensive numerical handling; values were then produced, but the subsequent plots showed that the results were unreliable, and training ran about half as fast as before. As I said, I am not a developer or statistician, and what I can do is very limited.
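The defensive handling mentioned above is not shown in the thread. As an illustration only, the hypothetical helpers below sketch two common guards: a numerically stable softplus (plus a small floor) that keeps an inverse-dispersion parameter strictly positive before it reaches lgamma, and a check that returns the first training step whose loss is NaN or infinite, which is one way to answer when the NaNs first appear. Guards like these can mask an underlying bug rather than fix it, so finite but unreliable results are possible.

```python
import math


def stable_softplus(raw, floor=1e-4):
    """log(1 + exp(raw)) computed without overflow, plus a floor so the
    result is strictly positive (hypothetical guard for theta)."""
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw))) + floor


def first_nonfinite_step(losses):
    """Return the index of the first NaN/inf loss value, or None if all are finite."""
    for step, value in enumerate(losses):
        if not math.isfinite(value):
            return step
    return None


# Usage with made-up loss values for illustration
history = [812.4, 640.1, float("nan"), float("nan")]
print(first_nonfinite_step(history))  # 2
print(stable_softplus(-1000.0) > 0)  # True
```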

c0nleyinnnn · Mar 23 '25 08:03

An update: I have now largely reproduced the core scVI module, but the code currently runs very slowly and the loss is clearly abnormal. After some troubleshooting, I think this is because MLX has no native mlx.lgamma function, so I am relying on a Lanczos approximation written in Python instead. I have filed an official feature request with MLX, and until they respond I think this work will be on hold indefinitely!
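For reference, a Lanczos approximation of lgamma looks roughly like the sketch below. It assumes the commonly used g = 7, n = 9 coefficient set and uses the reflection formula for arguments below 0.5; it illustrates the general technique and is not the code from the mlxSCVI repository. Written as a scalar Python function applied element by element, it re-enters the interpreter for every entry, which is one plausible reason for the slowdown; the same arithmetic (additions, divisions, logs) could instead be expressed with vectorized array operations. Comparing it against `math.lgamma` over the range of arguments the loss actually sees (including small `theta`) is a quick way to check whether the approximation itself contributes to the abnormal loss.

```python
import math

# Lanczos coefficients for g = 7, n = 9 (a commonly published set)
_LANCZOS_G = 7
_LANCZOS_COEFFS = [
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
]


def lanczos_lgamma(z):
    """Approximate log|Gamma(z)| for real z that is not a non-positive integer."""
    if z < 0.5:
        # Reflection formula: Gamma(z) * Gamma(1 - z) = pi / sin(pi * z)
        return math.log(math.pi / abs(math.sin(math.pi * z))) - lanczos_lgamma(1.0 - z)
    z -= 1.0
    series = _LANCZOS_COEFFS[0]
    for i in range(1, len(_LANCZOS_COEFFS)):
        series += _LANCZOS_COEFFS[i] / (z + i)
    t = z + _LANCZOS_G + 0.5
    return 0.5 * math.log(2.0 * math.pi) + (z + 0.5) * math.log(t) - t + math.log(series)


# Compare with the standard library at a few points
for value in [0.1, 0.5, 1.0, 2.0, 3.7, 10.0, 100.5]:
    print(value, lanczos_lgamma(value), math.lgamma(value))
```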

c0nleyinnnn · Apr 07 '25 06:04

@c0nleyinnnn Note that I added your code plus several fixes to a PR in our repo. Feel free to contribute to it by forking this branch.

ori-kron-wis · Nov 12 '25 14:11