Add torch neuron support for scvi tools

Open LinearParadox opened this issue 1 year ago • 3 comments

Torch Neuron is a PyTorch integration that lets models run on AWS Trainium and Inferentia accelerator instances. These are somewhat cheaper, especially for models that are a little too large for GPUs such as T4s but not worth an A100, so it would be a nice addition.

LinearParadox, Mar 19 '24 18:03

I'm not familiar with torch-neuron, what sort of changes would be necessary in order to enable this?

martinkim0, Mar 19 '24 19:03

I think there would likely have to be some analogous code added. For example, something like:

if use_neuron:
    ...  # Neuron training code
else:
    ...  # normal training code

The API seems pretty analogous to PyTorch:

https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/training/pytorch-neuron-programming-guide.html#pytorch-neuronx-programming-guide

I'm not super experienced with torch, but I can also try to dig in after this week to try and see if it's a trivial modification or entails a larger redesign.

The one major difference that might pose an issue is that Neuron builds graphs lazily, while PyTorch executes eagerly by default. Not sure how impactful this will be in practice, though.
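
To make the lazy-graph point a bit more concrete, here is a rough sketch of what such a branch might look like, assuming the Neuron path goes through `torch_xla` as in the guide linked above. `neuron_available`, `pick_backend`, and `train_step` are made-up names for illustration, not scvi-tools or Neuron APIs; the only real library call is `xm.mark_step()`, which is the point where the lazily traced graph is cut, compiled, and run.

```python
import importlib.util


def neuron_available() -> bool:
    """Hypothetical check: True if torch_xla can be imported in this environment."""
    return importlib.util.find_spec("torch_xla") is not None


def pick_backend(use_neuron: bool) -> str:
    """Hypothetical helper: 'xla' for the Neuron path, 'eager' for plain PyTorch."""
    if use_neuron and neuron_available():
        return "xla"
    return "eager"


def train_step(model, batch, optimizer, loss_fn, backend: str):
    """One optimisation step.

    Assumes model and batch were already moved to the right device
    (an XLA device for the 'xla' backend, CPU or GPU otherwise).
    """
    optimizer.zero_grad()
    loss = loss_fn(model(batch["x"]), batch["y"])
    loss.backward()
    optimizer.step()
    if backend == "xla":
        # Imported lazily so the eager path works without torch_xla installed.
        import torch_xla.core.xla_model as xm

        # Operations above were only recorded; this executes the traced graph.
        xm.mark_step()
    return loss
```

In this sketch the loop body is shared and only the `mark_step()` call differs between the two paths. Whether the real training code can be split that cleanly is still the open question.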

LinearParadox, Mar 19 '24 19:03

Let me know if you're able to look into this! Would be happy to take a PR if it's a small modification. If it seems like it's going to be a larger redesign, I think it would make sense for us to discuss it internally before anything is implemented.

martinkim0, Mar 20 '24 03:03