scvi-tools
Multi GPU training
- Write a custom DistributedSampler that also takes as input the overall set of indices to pull data from (i.e., the train, validation, or test set indices). This probably only needs a few lines of code: extend __init__, call the super __init__, and then write a custom __iter__ method (see the sketch below).
- Use this sampler when multi-GPU training is selected (these are all passed through kwargs of the PyTorch Lightning trainer).
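A minimal sketch of that sampler idea, assuming all we need is to restrict the standard torch.utils.data.DistributedSampler to a caller-supplied index subset (the class name and constructor arguments here are hypothetical, not part of scvi-tools):

```python
from torch.utils.data import DistributedSampler


class SubsetDistributedSampler(DistributedSampler):
    """DistributedSampler restricted to a caller-supplied set of dataset indices."""

    def __init__(self, dataset, indices, **kwargs):
        # DistributedSampler only uses len(dataset) for its rank/shard bookkeeping,
        # so we hand it the index subset instead of the full dataset.
        super().__init__(indices, **kwargs)
        self.indices = indices

    def __iter__(self):
        # The parent yields positions into `self.indices` for this rank;
        # map them back to real dataset indices.
        for pos in super().__iter__():
            yield self.indices[pos]
```

A DataLoader for the train split would then be built with, e.g., `sampler=SubsetDistributedSampler(dataset, train_idx, num_replicas=world_size, rank=rank)` and `shuffle` left unset.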
I would be also very interested in multi-GPU training of pyro models, specifically full data training mode where large data is split between different GPUs.
It is actually fairly straightforward to do data parallelism in pyro using horovod (https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_horovod.py). This would mainly require 1) a new training plan (to use DistributedSampler, a different optimizer) and 2) a different device-backed data loader (to split data across devices).
There are issues with using this for models with both global and local cell-specific parameters (all parameters live on all devices).
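For reference, a rough sketch of the data-parallel pattern from the linked Pyro Horovod example; `model`, `guide`, and `dataset` are placeholders and none of this is scvi-tools code:

```python
import horovod.torch as hvd
import pyro
import torch

hvd.init()
torch.cuda.set_device(hvd.local_rank())

# 1) shard the data: each rank only sees its own subset of cells
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, num_replicas=hvd.size(), rank=hvd.rank()
)
loader = torch.utils.data.DataLoader(dataset, batch_size=128, sampler=sampler)

# 2) wrap the Pyro optimizer so gradients are averaged across ranks
optim = pyro.optim.HorovodOptimizer(pyro.optim.Adam({"lr": 1e-3}))

svi = pyro.infer.SVI(model, guide, optim, pyro.infer.Trace_ELBO())
for epoch in range(10):
    sampler.set_epoch(epoch)  # reshuffle shards each epoch
    for batch in loader:      # assumes the dataset yields plain tensors
        svi.step(batch.cuda())
```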
It is likely less straightforward to integrate this with PyTorch Lightning, and it might require substantial work to support Pyro generally. In particular, we'd have to look more into how device-backed data loaders would work in this case.
I see. I also got feedback from @fritzo that, for models with local parameters, this might not give as much memory headroom as one would hope, because cell-specific parameters are already quite large for just a few hundred thousand cells (though it could still give 4-5x more space). I don't quite understand PyTorch Lightning, so if you solve this I would be very keen to try.
Does numpyro+jax more natively support multi-GPU training? If yes this could be a way to go.
What I am specifically interested in is data and model parameter parallelism where the data and model parameters for different cells (denoted by a plate) are distributed to different GPU devices. Maybe this is also possible with pyro.
Also cc @fehiepsi @fritzo @martinjankowiak
[As mentioned above] Pyro can use Horovod for data parallelism across GPUs and machines in a cluster, but I believe parameters would be replicated on all nodes. NumPyro might be the way to go. @fehiepsi?
Current NumPyro SVI does not support that pattern, but it might be doable using JAX directly. Something like:
```python
def loss_fn(batch, params):
    global_params, local_params = params
    model_g = handlers.substitute(model, data=global_params)
    guide_g = handlers.substitute(guide, data=global_params)

    def get_loss_local(data, local_params):
        model_l = handlers.substitute(model_g, data=local_params)
        guide_l = handlers.substitute(guide_g, data=local_params)
        loss = TraceELBO(model_l, guide_l, ...)
        return loss

    return jax.pmap(get_loss_local)(batch, local_params)

# then use jaxopt to optimize loss_fn over params:
# https://jaxopt.github.io/stable/stochastic.html#optax-solvers
```
though it still seems a bit tricky to cover many usage cases (e.g., when there are both global and local variables, we need to apply a reduced sum over the per-device local terms).
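A hedged illustration of that reduced-sum point in plain JAX (not NumPyro API; the loss terms are placeholders): gradients for the replicated global parameters have to be summed across devices, while the sharded local parameters keep their per-device gradients.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.pmap, axis_name="devices", in_axes=(None, 0, 0))
def device_grads(global_params, local_params, batch):
    def loss(g, l):
        local_term = jnp.sum((batch - l) ** 2)  # per-cell terms on this shard
        global_term = jnp.sum(g ** 2)           # shared term, replicated on every device
        return local_term + global_term

    g_grad, l_grad = jax.grad(loss, argnums=(0, 1))(global_params, local_params)
    # global gradients must be reduced (summed) across devices;
    # local gradients stay with their own shard
    return jax.lax.psum(g_grad, axis_name="devices"), l_grad
```

Here `local_params` and `batch` carry a leading device axis (one shard per GPU), while `global_params` is broadcast to all devices via `in_axes=None`.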
Thanks for your thoughts!
My models always have both local and global variables. Do you see any way to define the device split along a pyro plate? Maybe that could be provided as an option in numpyro?
Does pure data parallelism work with current numpyro and scvi-tools? (Loading different minibatches of data to different devices, 8 minibatches in parallel but different cells in each training iteration)
@adamgayoso sorry, do you have a recipe if I want to enable multi-GPU training of the scVI model before it is released in scvi-tools? I haven't done multi-GPU training before, so I'm asking where to start. Can I just apply a patch from #1357?
@adamgayoso If you ignore device-backed data loaders for now, what is the main roadblock to implementing the Pyro+horovod solution? https://pyro.ai/examples/svi_horovod.html
Does this boil down to implementing an equivalent to torch.utils.data.distributed.DistributedSampler and modifying the training plan to add horovod use? Or is there more to it?
Is the problem writing a general solution that works for any model, or is the problem that this approach fundamentally won't work for some models?
Also cc @macwiatrak for discussion
I expect we'll only have minimal issues with non-Pyro models, due to updates in Lightning that automatically wrap custom dataloader samplers like the ones we have.
In the case of Pyro, we have a somewhat hacky solution that fuses it with a lightning module. I expect significantly more engineering work to get this right. A hacky solution might be quicker, but we shouldn't include that in this library.
To clarify, lightning should handle:
- Automatically creating the distributed data loader (recent updates should allow this to work with no code changes on our side)
- Broadcasting the params and optimizers across devices
But this is in the default pytorch case. For Pyro, which lazily initializes params, the hacky solution would involve a callback that does some of the things you see in the linked pyro tutorial.
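A hypothetical sketch of such a callback (the class name, the hook choice, and the `pl_module.module.model`/`guide` attributes are assumptions, not the scvi-tools API): run one forward pass through the model and guide before training so that Pyro's lazily created parameters exist on every rank before the strategy wraps or broadcasts them.

```python
import pytorch_lightning as pl
import torch


class PyroParamWarmup(pl.Callback):
    """Touch the model/guide once so lazily created Pyro params exist on every rank."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def on_fit_start(self, trainer, pl_module):
        batch = next(iter(self.dataloader))
        with torch.no_grad():
            # assumes the training plan exposes the Pyro model/guide pair
            pl_module.module.model(batch)
            pl_module.module.guide(batch)
```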
@adamgayoso I'd love to make this easier to do in Pyro (as @vitkl has requested). What's your timeline? Could we sync the week of the Jan 23 to figure out what would be needed on the Pyro side?
We don't have bandwidth to contribute much at the moment, but can review code. I think it's relatively straightforward to make this work in the nn.Module/PyroModule paradigm by altering what we call a TrainingPlan to use vanilla torch optimizers instead of Pyro optimizers. This will allow lightning to do almost all the work.
In other words, we can create a LowerLevelPyroTrainingPlan, using this lower level pattern internally.
```python
import pyro
import torch

# trace one loss evaluation to collect the (lazily initialized) Pyro params
loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model, guide)
params = set(site["value"].unconstrained()
             for site in param_capture.trace.nodes.values())
# hand the unconstrained params to a vanilla torch optimizer
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))
```
Then lightning will do everything it needs to do with handling backprop.
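A rough sketch of what such a LowerLevelPyroTrainingPlan could look like (hypothetical names, not the scvi-tools implementation), assuming the model and guide are nn.Module/PyroModule instances whose parameters have already been initialized (e.g. by a warmup callback):

```python
import pyro
import pytorch_lightning as pl
import torch


class LowLevelPyroPlan(pl.LightningModule):
    def __init__(self, model, guide, lr=1e-3):
        super().__init__()
        self.model, self.guide = model, guide
        self.elbo = pyro.infer.Trace_ELBO()
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # differentiable_loss returns an ordinary torch tensor, so Lightning's
        # default backward/optimizer-step machinery (and hence DDP) just works
        return self.elbo.differentiable_loss(self.model, self.guide, batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```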
We already have a callback to run the param initialization, which can then be changed to reset the optimizer
@fritzo what is ELBOModule? This seems useful. I can maybe put up a draft PR with my idea
@eb8680 can tell you more about ELBOModule
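For context, a hedged sketch of the ELBOModule pattern: in recent Pyro releases, calling an ELBO instance on a model/guide pair returns a torch.nn.Module whose forward pass is the differentiable loss. The toy model below is illustrative only.

```python
import pyro
import pyro.distributions as dist
import torch


def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("N", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


guide = pyro.infer.autoguide.AutoNormal(model)

# calling the ELBO on (model, guide) returns an nn.Module (ELBOModule)
elbo = pyro.infer.Trace_ELBO()(model, guide)

data = torch.randn(5)
loss = elbo(data)  # first call initializes the guide's lazy parameters
optim = torch.optim.Adam(elbo.parameters(), lr=1e-3)
loss.backward()
optim.step()
```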
It seems @vitkl has made it work via the new lower-level training plan: https://github.com/scverse/scvi-tools/pull/1845
It appears that the Lightning "ddp_notebook" strategy + https://github.com/scverse/scvi-tools/pull/1845 doesn't allow saving the trained model, because Lightning fails to load the checkpoint saved by the worker GPU. It looks like the model parameters don't exist yet in the main process at the time the state_dict is loaded. I tried creating a callback that does a forward pass through the model in on_fit_end, on_train_end, and on_load_checkpoint; however, none of that changed anything. Maybe this means the strategy or the training plan needs to be modified, but I don't understand what to try next. Any tips would be appreciated @fritzo @eb8680 @adamgayoso
I have not tested the standard "ddp" strategy in a script yet.
```
File /nfs/team205/vk7/sanger_projects/my_packages/cell2state/cell2state/models/base/trainrunner.py:81, in TrainRunnerLowLevel.__call__(self)
     78 if hasattr(self.data_splitter, "n_val"):
     79     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 81 self.trainer.fit(self.training_plan, self.data_splitter)
     82 self._update_history()
     84 # data splitter only gets these attrs after fit

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/scvi/train/_trainer.py:187, in Trainer.fit(self, *args, **kwargs)
    181 if isinstance(args[0], PyroTrainingPlan):
    182     warnings.filterwarnings(
    183         action="ignore",
    184         category=UserWarning,
    185         message="LightningModule.configure_optimizers returned None",
    186     )
--> 187 super().fit(*args, **kwargs)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:603, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    601 raise TypeError(f"Trainer.fit() requires a LightningModule, got: {model.__class__.__qualname__}")
    602 self.strategy._lightning_module = model
--> 603 call._call_and_handle_interrupt(
    604     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    605 )

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:36, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     34 try:
     35     if trainer.strategy.launcher is not None:
---> 36         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37     else:
     38         return trainer_fn(*args, **kwargs)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:123, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    120 if trainer is None:
    121     return worker_output
--> 123 self._recover_results_in_main_process(worker_output, trainer)
    124 return worker_output.trainer_results

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:156, in _MultiProcessingLauncher._recover_results_in_main_process(self, worker_output, trainer)
    154 if worker_output.weights_path is not None:
    155     ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
--> 156     trainer.lightning_module.load_state_dict(ckpt)
    157     self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
    159 trainer.state = worker_output.trainer_state

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
   1492 error_msgs.insert(
   1493     0, 'Missing key(s) in state_dict: {}. '.format(
   1494         ', '.join('"{}"'.format(k) for k in missing_keys)))
   1496 if len(error_msgs) > 0:
-> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1498         self.__class__.__name__, "\n\t".join(error_msgs)))
   1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for LowLevelPyroTrainingPlan:
	Unexpected key(s) in state_dict: "module._model.accessibility_bias_mean_prior", ... "module._guide.scales.weights.tf_dna_binding_preference.weights.motif_weight_unconstrained" ...
```
@martinkim0 Nice work! Great to have this supported.
Is it possible to combine this approach with DeviceBackedDataSplitter? I am interested in loading data subsets once - simply distributing different cells or spatial locations across GPUs.