
Fixes to pyro model initialisation & sampling [WIP]

Open vitkl opened this issue 1 year ago • 7 comments

Addresses https://github.com/scverse/scvi-tools/issues/2616

Replaces https://github.com/scverse/scvi-tools/pull/1805

vitkl avatar Apr 07 '24 23:04 vitkl

I don't fully understand the reason for the errors - they don't happen in test_pyro_bayesian_regression_low_level, test_pyro_bayesian_regression, test_pyro_bayesian_regression_jit - but they happen when using train() directly. This approach works for cell2location.

The difference may be the timing of when the plates are first used. I will look into this later.

vitkl avatar Apr 08 '24 01:04 vitkl

Also, this code for posterior sampling is indeed ~2-3x faster, but it creates samples of the huge observed data matrices (it copies the data n_samples times, e.g. 1000):

        if isinstance(self.module.guide, poutine.messenger.Messenger):
            # This already includes trace-replay behavior.
            sample = self.module.guide(*args, **kwargs)

An alternative way to deal with this issue would be this:

        if isinstance(self.module.guide, poutine.messenger.Messenger):
            # This already includes trace-replay behavior.
            sample = self.module.guide(*args, **kwargs)
            # include and exclude requested sites
            sample = {k: v for k, v in sample.items() if k in return_sites}
            sample = {k: v for k, v in sample.items() if k not in exclude_vars}   # this has to be provided by model developer
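The include/exclude filtering above can be tried in isolation. The following is a minimal sketch only, not code from this PR: the helper name `filter_sample_sites`, the site names in the toy dictionary, and the choice to treat `return_sites=None` as "keep everything" are all assumptions made for illustration.

```python
def filter_sample_sites(sample, return_sites=None, exclude_vars=()):
    """Keep requested sites and drop excluded ones in a single pass.

    Hypothetical helper: `return_sites=None` is assumed to mean "keep all sites".
    """
    excluded = set(exclude_vars)
    return {
        name: value
        for name, value in sample.items()
        if (return_sites is None or name in return_sites) and name not in excluded
    }


# Toy posterior sample: two latent sites plus one large observed-data site.
sample = {
    "w": [0.1, 0.2],
    "sigma": [1.0],
    "data_target": [[5, 3], [2, 7]],
}

filtered = filter_sample_sites(
    sample,
    return_sites=["w", "sigma", "data_target"],
    exclude_vars=["data_target"],  # observed sites listed by the model developer
)
print(sorted(filtered))
```

Dropping the observed sites this way happens after they have been sampled, so it only reduces what is kept in the returned dictionary.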

@martinkim0 What do you think we should do? What do you think about the initialisation solution?

vitkl avatar Apr 08 '24 01:04 vitkl

@vitkl hey sorry for the delay, I'm planning on taking a look at this tomorrow!

martinkim0 avatar Apr 12 '24 01:04 martinkim0

This is actually my first time taking a look at some of our Pyro code - I hadn't really interacted with it before. So I don't really understand why some things are done the way they are, e.g., the warmup callbacks. I definitely need to take a deep dive into all of this.

However, it looks like both PyroJitGuideWarmup and PyroModelGuideWarmup are just passing in a single minibatch through the guide prior to the training loop, so I like the idea of having a method like setup_pyro_model that does this. I think this makes more sense in the training plan though, using one of the Lightning hooks such as on_train_start. And there's definitely something weird going on with tensors on different devices, and I think using one of the Lightning hooks would solve this since their backend will take care of moving tensors.

Regarding the sampling changes, would it be possible to include that in a separate PR? And then we can discuss that there. Thanks!

martinkim0 avatar Apr 12 '24 21:04 martinkim0

Just a brief reply. Happy to have a zoom call about pyro.

Pyro's automatic variational distribution (the guide) doesn't have any parameters until you do a first pass through the model and guide. When moving my code to multi-GPU training I found that this needs to be done in the setup step of the Lightning workflow - otherwise parameters created on GPU don't get moved between devices correctly - so it would not work in on_train_start. However, in the latest version the setup step also doesn't work - as reported in the original issue. Moving the code to this function and calling it before using any Lightning workflow steps seems to solve the problem for cell2location and my other project.

Actually, the errors might be resolved if you call both the model and the guide with one batch (this is possibly the issue with the LDA model, which uses a custom guide).
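To illustrate the ordering described above, here is a toy sketch in plain Python. It does not use Pyro or Lightning: `LazyGuide` and `warmup` are hypothetical stand-ins, assumed only to mimic a guide that creates its parameters on the first call, with one minibatch passed through it before any training step runs.

```python
class LazyGuide:
    """Toy stand-in for a guide whose parameters only appear on the first call."""

    def __init__(self):
        self.params = {}

    def __call__(self, batch):
        for site, values in batch.items():
            # Create one location parameter per site the first time it is seen.
            self.params.setdefault(f"{site}_loc", [0.0] * len(values))
        return self.params


def warmup(guide, batches):
    """Pass a single minibatch through the guide before the training loop."""
    first_batch = next(iter(batches))
    guide(first_batch)
    return guide


batches = [{"x": [1.0, 2.0, 3.0]}, {"x": [4.0, 5.0, 6.0]}]

guide = LazyGuide()
n_params_before = len(guide.params)  # nothing exists yet
warmup(guide, batches)
n_params_after = len(guide.params)  # parameters now exist and can be moved or optimised
print(n_params_before, n_params_after)
```

In the real setting the point of running this step early is that the parameters exist before the trainer decides where tensors live; the sketch only shows the "no parameters until the first pass" behaviour, not device placement.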

vitkl avatar Apr 12 '24 22:04 vitkl

Please split into two PRs. One for the warmup changes and one for the inference changes. This makes it easier to follow changes.

canergen avatar Sep 05 '24 18:09 canergen