feat(train): Add support for torch.compile (EXPERIMENTAL)
@ori-kron-wis Can you add compile tests for all PyTorch models (not Pyro and not JAX)? Can you also check the speed improvements on your end? Run it with: `model.train(accelerator='cuda', plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})`
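One likely reason for `drop_last=True` in that call (our reading, not stated in the thread): `torch.compile` specializes the compiled graph on input shapes, so a smaller final batch can trigger a recompilation. The pure-Python sketch below, with hypothetical helper names, shows how `drop_last` changes the batch sizes a loader yields. `torch.utils.data.DataLoader` exposes the same behavior through its `drop_last` argument.

```python
def batch_sizes(n_obs, batch_size, drop_last):
    """Mini-batch sizes a loader yields for n_obs samples (shuffling does not change them)."""
    full, remainder = divmod(n_obs, batch_size)
    sizes = [batch_size] * full
    # Without drop_last, the leftover samples form one smaller final batch.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


def distinct_batch_shapes(n_obs, batch_size, drop_last):
    """How many different batch sizes (leading input dimensions) the model sees."""
    return len(set(batch_sizes(n_obs, batch_size, drop_last)))


# 10 samples, batch size 4: [4, 4, 2] without drop_last, [4, 4] with it.
print(batch_sizes(10, 4, drop_last=False), batch_sizes(10, 4, drop_last=True))
```

The trade-off is that up to `batch_size - 1` samples are skipped per epoch. With shuffling enabled, a different subset is dropped each epoch.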
It needs tests like `model2.train(accelerator='cuda', batch_size=5000, max_epochs=100, train_size=0.9, plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})`, followed by calls to `get_elbo`, `get_reconstruction_loss`, and `get_latent`.
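A minimal sketch of such a test for one model (SCVI), assuming a CUDA machine and the `compile` plan kwarg added in this PR. It also assumes the comment's shorthand `get_reconstruction_loss` and `get_latent` refers to scvi-tools' `get_reconstruction_error` and `get_latent_representation`. The helper `compile_train_kwargs` and the test name are hypothetical. Batch size and epoch count are shrunk so the small synthetic dataset still yields full batches under `drop_last`.

```python
import math


def compile_train_kwargs(batch_size=5000, max_epochs=100):
    """Build the train() kwargs quoted in the review comment (hypothetical helper)."""
    return {
        "accelerator": "cuda",
        "batch_size": batch_size,
        "max_epochs": max_epochs,
        "train_size": 0.9,
        "plan_kwargs": {"n_epochs_kl_warmup": 100, "compile": True},
        "datasplitter_kwargs": {"drop_last": True},
    }


def test_scvi_train_with_compile():
    # Optional dependencies are imported lazily; the test skips if they are missing.
    import pytest

    torch = pytest.importorskip("torch")
    scvi = pytest.importorskip("scvi")
    if not torch.cuda.is_available():
        pytest.skip("compile test expects a CUDA device")

    adata = scvi.data.synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    model = scvi.model.SCVI(adata)
    # Small batches and few epochs keep the smoke test fast on synthetic data.
    model.train(**compile_train_kwargs(batch_size=32, max_epochs=2))

    # Post-training queries corresponding to the comment's shorthand names.
    elbo = model.get_elbo()
    recon = model.get_reconstruction_error()
    latent = model.get_latent_representation()

    assert math.isfinite(float(elbo))
    assert recon  # dict of reconstruction-loss values
    assert latent.shape[0] == adata.n_obs
```

Other PyTorch-based models could reuse `compile_train_kwargs` through `pytest.mark.parametrize` over model classes, each with its own `setup_anndata` call.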
I added `torch.compile` tests for most models. They don't pass in the GitHub Action because of that error, but on new servers it worked fine and was faster, although the compilation step adds some overhead.
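Because compilation is a one-off cost paid mostly in the first steps, a speed comparison is clearer when warm-up is reported separately from steady-state epochs. A minimal stdlib sketch, assuming one training epoch can be wrapped in a zero-argument callable. The helper names are hypothetical and nothing here is an scvi-tools API.

```python
import math
import time
from statistics import mean


def split_compile_overhead(run_epoch, n_epochs=5):
    """Time n_epochs calls of run_epoch(); return (first_epoch_s, steady_epoch_s).

    With torch.compile, the first epoch also pays for graph capture and code
    generation, so it is reported separately from the mean of the later epochs.
    """
    if n_epochs < 2:
        raise ValueError("need at least 2 epochs to separate warm-up from steady state")
    durations = []
    for _ in range(n_epochs):
        start = time.perf_counter()
        run_epoch()
        durations.append(time.perf_counter() - start)
    return durations[0], mean(durations[1:])


def break_even_epochs(compile_overhead_s, eager_epoch_s, compiled_epoch_s):
    """Number of epochs after which the one-off compile cost is recovered."""
    saving_per_epoch = eager_epoch_s - compiled_epoch_s
    if saving_per_epoch <= 0:
        return math.inf  # compiled epochs are not faster, so compiling never pays off
    return math.ceil(compile_overhead_s / saving_per_epoch)
```

Take the compiled run's first epoch minus its steady-state mean as the overhead, then pass it with both runs' steady-state means to `break_even_epochs`. If `run_epoch` launches CUDA work, it should call `torch.cuda.synchronize()` before returning; otherwise the timer can stop before the GPU finishes.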
Currently one Pyro test, `test_pyro_bayesian_regression`, fails on a multi-GPU machine, and I still need to find out why. With that test removed, everything passes (it should pass here).
Codecov Report
Attention: Patch coverage is 33.33333% with 4 lines in your changes missing coverage. Please review.
Project coverage is 82.75%. Comparing base (958c253) to head (3473bee). Report is 62 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/scvi/train/_trainingplans.py | 33.33% | 4 Missing :warning: |
:exclamation: There is a different number of reports uploaded between BASE (958c253) and HEAD (3473bee).

HEAD has 43 fewer uploads than BASE:

| Flag | BASE (958c253) | HEAD (3473bee) |
|---|---|---|
| | 46 | 3 |
Additional details and impacted files
Coverage diff (main vs. #2931):

| | main | #2931 | +/- |
|---|---|---|---|
| Coverage | 89.78% | 82.75% | -7.03% |
| Files | 181 | 181 | |
| Lines | 15629 | 15443 | -186 |
| Hits | 14032 | 12780 | -1252 |
| Misses | 1597 | 2663 | +1066 |
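The coverage figures in this report are internally consistent: coverage is hits divided by lines, and misses are lines minus hits (no partial lines are reported here). A minimal sketch that recomputes them, assuming Codecov's default of two decimals rounded down (both configurable in `codecov.yml`). The helper name `codecov_percent` is hypothetical. The patch figure assumes no partial lines, so 4 uncovered lines at 33.33% means 2 of 6 patch lines were hit.

```python
def codecov_percent(hits, lines, precision=2):
    """Coverage percentage truncated (rounded down) to `precision` decimals.

    Integer floor division keeps the truncation exact.
    """
    scale = 10 ** precision
    return (hits * 100 * scale // lines) / scale


# Project coverage from the diff table: (hits, lines) for base and head.
base = codecov_percent(14032, 15629)  # main
head = codecov_percent(12780, 15443)  # PR head
misses_base = 15629 - 14032
misses_head = 15443 - 12780

# Patch coverage: 2 covered out of 6 changed lines (assuming no partials).
patch = codecov_percent(2, 6)

print(base, head, round(base - head, 2), misses_base, misses_head, patch)
# prints: 89.78 82.75 7.03 1597 2663 33.33
```

Truncation matters here: rounding 12780 / 15443 to the nearest hundredth would give 82.76% rather than the reported 82.75%.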
| Files with missing lines | Coverage Δ | |
|---|---|---|
| src/scvi/train/_trainingplans.py | 92.85% <33.33%> (-0.76%) | :arrow_down: |
This branch, excluding its tests, was merged into main together with the MPS fix in https://github.com/scverse/scvi-tools/pull/3100.