Using torch.compile in Pyro models
Hi @adamgayoso and others (also cc @fritzo, @martinjankowiak, @eb8680)
It would be great if the new torch.compile function could be used with the Pyro model and guide in scvi-tools.
I am happy to contribute this functionality, however, I need your recommendations on what to do with the following problem. Suppose we add torch.compile as shown below:
import torch
from scvi.module.base import PyroBaseModuleClass

class MyBaseModule(PyroBaseModuleClass):
    def __init__(
        self,
        model,
        guide_class,  # such as AutoNormalMessenger
        **kwargs,
    ):
        """
        Module class which defines an AutoGuide given a model.
        """
        super().__init__()
        self.hist = []
        _model = model(**kwargs)
        self._model = torch.compile(_model)
        _guide = guide_class(_model)  # AutoGuide classes are constructed from the model
        self._guide = torch.compile(_guide)
The problem is that Pyro creates guide parameters lazily, when they are first needed, which is why these callbacks are required: https://github.com/scverse/scvi-tools/blob/a210867bf54c43283e14b553cff473ba2c7080c5/scvi/model/base/_pyromixin.py#L19-L71. My understanding is that this means torch.compile(_guide) should likewise be called only after the parameters have been created.
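To illustrate the lazy initialization with a toy model (a sketch, not an scvi-tools module):

import pyro
import pyro.distributions as dist
import torch
from pyro.infer.autoguide import AutoNormal

def model(x):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)

guide = AutoNormal(model)
print(list(guide.parameters()))  # [] -- nothing to compile yet

guide(torch.randn(5))            # first call creates the param sites
print(list(guide.parameters()))  # now non-empty (loc and scale tensors)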
I see one solution to this: run the following code https://github.com/scverse/scvi-tools/blob/a210867bf54c43283e14b553cff473ba2c7080c5/scvi/model/base/_pyromixin.py#L65-L71 in model.train() manually, without using a callback, after creating the data loaders but before creating TrainRunner and TrainingPlan.
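Paraphrased as a plain function that model.train() could call once the data loaders exist (a sketch of the linked callback logic, not the exact scvi-tools code; _get_fn_args_from_batch is the scvi-tools helper that maps a batch dict to model arguments):

def warmup_guide(module, dataloader, device):
    # run one batch through the guide so that all pyro.param sites exist
    for tensors in dataloader:
        tensors = {k: t.to(device) for k, t in tensors.items()}
        args, kwargs = module._get_fn_args_from_batch(tensors)
        module.guide(*args, **kwargs)
        break  # one batch is enough to create the parameters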
Then modify the training plan as follows:
import pyro
import torch
from scvi.train import LowLevelPyroTrainingPlan

class PyroCompiledTrainingPlan(LowLevelPyroTrainingPlan):
    """
    Lightning module task to train Pyro scvi-tools modules.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # compile model and guide once the plan owns the module,
        # then hand the compiled callables to SVI
        self.module.model_compiled = torch.compile(self.module.model)
        self.module.guide_compiled = torch.compile(self.module.guide)
        self.svi = pyro.infer.SVI(
            model=self.module.model_compiled,
            guide=self.module.guide_compiled,
            optim=self.optim,
            loss=self.loss_fn,
        )
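For context, the plan could then be wired up inside model.train() roughly as follows (a sketch following the usual scvi-tools plan-plus-TrainRunner pattern; the argument names here are assumptions, not the exact API):

# sketch only: argument names may differ from the real TrainRunner API
training_plan = PyroCompiledTrainingPlan(pyro_module=self.module)
runner = TrainRunner(
    self,
    training_plan=training_plan,
    data_splitter=data_splitter,
    max_epochs=max_epochs,
)
runner()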
What do you think about this? Do you have any better ideas on how to implement this?
Hi @vitkl, what's the purpose of this request? In our hands the speed-up was not reproducibly high. Has your experience been different? We see larger speed-ups using JAX, so if the goal is speed, rewriting the model in numpyro would be the best option.
The purpose is to enable general support for Pyro scvi-tools models. Some models likely benefit more than others, but it is good to have the option. Pyro adds extra challenges to using torch.compile (as mentioned above), so I was not able to try torch.compile without additional input on how to solve those issues. Once I understand how to implement this, I will test it on cell2location and other unpublished models, both of which use all available GPU memory (including multi-GPU), unlike scVI, which mostly uses just a few GB.
Re-implementing models in numpyro is not always practical because i) numpyro doesn't cover all of Pyro's functionality, and ii) we have observed in the past that JAX uses 2-4x the GPU memory for the same data size, meaning it is less practical for larger datasets where every bit of GPU memory matters.
I agree that the speed-up is expected to be largely model-dependent and that scVI is small and might be a bad proxy. Adam and Martin experimented with torch.compile, but only with the PyTorch models. I would expect it to be more straightforward to train the model/guide for one step (similar to our current load procedure) https://github.com/scverse/scvi-tools/blob/4965279f9606216dffaa45f851e1d1ae7c886879/scvi/module/base/_base_module.py#L388 and to compile the guide afterwards.
Do you suggest modifying model.train() as shown below?
self.module.on_load(self)
self.module._model = torch.compile(self.module.model)
self.module._guide = torch.compile(self.module.guide)
Are self.module.model and self.module.guide modifiable as shown here?
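For concreteness, this assumes the property pattern from the first snippet in this thread, where model/guide are read-only views over _model/_guide, so reassigning the private attributes swaps what SVI will call:

class MyBaseModule(PyroBaseModuleClass):
    @property
    def model(self):
        return self._model

    @property
    def guide(self):
        return self._guide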
As a proxy for the effect of compilation on cell2location, I can mention that our old Theano+PyMC3 implementation was 2-4 times faster for the same number of training steps. It would be great to see what happens here; a 2-4x speed-up would be really nice.
I tried it out on my side and got some cryptic error messages (though that was in a private repo with an unpublished model). My idea was to call self.train(max_steps=1) once and compile afterwards, i.e. to use a single training step as the guide warm-up. I'm happy to review if you open a PR.
I will try your suggestion. Do I understand correctly that you suggest something like the following?

def train(self, ...):
    # warm-up: a single step creates the guide parameters
    self.train(..., max_steps=1)
    self.module._model = torch.compile(self.module.model)
    self.module._guide = torch.compile(self.module.guide)
    self.train(...)

(In practice this would need to be a separate method to avoid recursion; see train_compiled below.)
Yes, that's my understanding of how we do guide warm-ups for Pyro (e.g. when loading a trained model). I don't think pyro.clear_param_store() is necessary here.
This is a good point. I will test it. Let's see what happens with cell2location.
Looks like torch.compile works for cell2location using the modified train method below. The code runs, but there is no speed benefit (5h 12min with vs 5h 25min without; using torch.set_float32_matmul_precision('high') on an A100 is more impactful: 4h 45min).
import torch
from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin

class MyModelClass(PyroSampleMixin, PyroSviTrainMixin, BaseModelClass):
    def train_compiled(self, **kwargs):
        # warm-up: one training step creates the guide parameters
        self.train(**kwargs, max_steps=1)
        # swap in compiled model and guide, then train fully
        self.module._model = torch.compile(self.module.model)
        self.module._guide = torch.compile(self.module.guide)
        self.train(**kwargs)
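Hypothetical usage (construction of adata omitted; the keyword values are placeholders, with batch_size=None standing in for cell2location-style full-data batches):

model = MyModelClass(adata)
model.train_compiled(max_epochs=500, batch_size=None)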
The model and guide are successfully replaced:
LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(
(dropout): Dropout(p=0.0, inplace=False)
)
AutoNormal(
(locs): PyroModule()
(scales): PyroModule()
)
OptimizedModule(
(_orig_mod): LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(
(dropout): Dropout(p=0.0, inplace=False)
)
)
OptimizedModule(
(_orig_mod): AutoNormal(
(locs): PyroModule()
(scales): PyroModule()
)
)
The PyTorch documentation says (https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html):
Speedup mainly comes from reducing Python overhead and GPU read/writes, and so the observed speedup may vary on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
I wonder if this means that speed-ups only appear for models that don't already have 100% GPU utilisation. Cell2location mainly uses very large full-data batches.
I also get errors if I attempt to use amortised inference (an encoder NN as part of the guide):
File /nfs/team283/vk7/software/miniconda3farm5/envs/cell2loc_env_2023/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py:1544, in ShapeGuardPrinter._print_Symbol(self, expr)
1538 def repr_symbol_to_source():
1539 return repr({
1540 symbol: [s.name() for s in sources]
1541 for symbol, sources in self.symbol_to_source.items()
1542 })
-> 1544 assert self.symbol_to_source.get(expr), (
1545 f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
1546 f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
1547 "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
1548 )
1549 return self.source_ref(self.symbol_to_source[expr][0])
AssertionError: s2 (could be from ["L['msg']['infer']['prior']._batch_shape[0]"]) not in {s0: ["L['msg']['value'].size()[0]"], s1: ["L['msg']['value'].size()[1]", "L['msg']['value'].stride()[0]"], s5: [], s2: [], s4: [], s3: []}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665