scvi-tools

feat(train) Add support for torch.compile (EXPERIMENTAL)

Open • canergen opened this pull request • 4 comments

canergen, Aug 07 '24

@ori-kron-wis Could you add compile tests for all PyTorch models (not the Pyro or JAX ones)? Could you also check the speed improvements on your end? Run it with:
model.train(accelerator='cuda', plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})

canergen, Aug 07 '24
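The thread does not say why drop_last is requested. A plausible reason, offered here only as an assumption, is that torch.compile specializes compiled graphs on input shapes, so a smaller final minibatch can trigger a recompilation. The sketch below uses a hypothetical helper, batch_sizes, and made-up dataset sizes to enumerate the minibatch sizes of one epoch with and without dropping the last incomplete batch:

```python
def batch_sizes(n_obs: int, batch_size: int, drop_last: bool) -> list[int]:
    """Hypothetical helper: size of every minibatch in one epoch."""
    n_full, remainder = divmod(n_obs, batch_size)
    sizes = [batch_size] * n_full
    # Without drop_last, the smaller trailing batch is kept,
    # so the model sees a second input shape each epoch.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


# Made-up example: 12,000 training cells with batch_size=5000.
print(batch_sizes(12_000, 5_000, drop_last=False))  # [5000, 5000, 2000]
print(batch_sizes(12_000, 5_000, drop_last=True))   # [5000, 5000]
```

With drop_last=False the compiled module sees two input shapes per epoch; with drop_last=True it sees one, at the cost of skipping up to batch_size - 1 cells per epoch.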

The tests should look like:
model2.train(accelerator='cuda', batch_size=5000, max_epochs=100, train_size=0.9, plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})
followed by calls to get_elbo, get_reconstruction_loss, and get_latent.

canergen, Aug 26 '24

I added torch.compile tests for most models. As expected, they do not run in the GitHub Action because of that error, but on newer servers they passed and training was faster, although compilation adds some upfront overhead.

Currently one Pyro test, test_pyro_bayesian_regression, fails on a multi-GPU machine, and I still need to find out why. With that test removed, everything passes (it should pass here).

ori-kron-wis, Sep 17 '24
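The trade-off described in the comment above (faster epochs, plus a compile cost) can be framed as a break-even calculation. This is a sketch under simplifying assumptions: a single one-off compile cost and constant per-epoch times. The helper break_even_epochs and all timings are hypothetical, not measurements from this PR:

```python
import math


def break_even_epochs(compile_overhead_s: float, eager_epoch_s: float, compiled_epoch_s: float) -> int:
    """Smallest epoch count at which compiled training is no slower than eager.

    Assumes one-off compile cost plus a constant per-epoch time in each mode.
    """
    saving_per_epoch = eager_epoch_s - compiled_epoch_s
    if saving_per_epoch <= 0:
        raise ValueError("compiled epochs must be faster than eager epochs to break even")
    # Smallest integer n with overhead + n * compiled <= n * eager.
    return math.ceil(compile_overhead_s / saving_per_epoch)


# Hypothetical timings: 60 s compile, 2.0 s/epoch eager, 1.5 s/epoch compiled.
print(break_even_epochs(60.0, 2.0, 1.5))  # 120
```

Any recompilation (for example on a new input shape) would add to the overhead term, which is one more reason short runs may not benefit from compiling.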

Codecov Report

Attention: Patch coverage is 33.33333% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.75%. Comparing base (958c253) to head (3473bee). Report is 62 commits behind head on main.

Files with missing lines           Patch %   Lines
src/scvi/train/_trainingplans.py   33.33%    4 missing ⚠

There is a different number of reports uploaded between BASE (958c253) and HEAD (3473bee): HEAD has 43 fewer uploads than BASE (46 at BASE 958c253, 3 at HEAD 3473bee).
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2931      +/-   ##
==========================================
- Coverage   89.78%   82.75%   -7.03%     
==========================================
  Files         181      181              
  Lines       15629    15443     -186     
==========================================
- Hits        14032    12780    -1252     
- Misses       1597     2663    +1066     
Files with missing lines           Coverage Δ
src/scvi/train/_trainingplans.py   92.85% <33.33%> (-0.76%) ↓

... and 31 files with indirect coverage changes


codecov[bot], Sep 18 '24
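The figures in the coverage report above can be reproduced as hits / lines. Truncating to two decimals, rather than rounding, is an assumption about how Codecov displays percentages; it is the choice that reproduces 82.75% here (ordinary rounding would give 82.76%). coverage_pct is a hypothetical helper:

```python
import math


def coverage_pct(hits: int, lines: int) -> float:
    """Coverage in percent, truncated (not rounded) to two decimals."""
    return math.floor(hits / lines * 10_000) / 100


base = coverage_pct(14032, 15629)  # main (958c253)
head = coverage_pct(12780, 15443)  # this PR (3473bee)
print(base, head)  # 89.78 82.75
print(15629 - 14032, 15443 - 12780)  # misses: 1597 2663

# Patch: 4 missing lines at 33.33% implies 6 patch lines, 2 of them hit.
print(coverage_pct(2, 6))  # 33.33
```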

This branch, excluding its tests, was merged into main together with the MPS fix in https://github.com/scverse/scvi-tools/pull/3100

ori-kron-wis, Dec 31 '24