
`@torch.compile` some tutorials

Open · guilhermeleobas opened this pull request 1 year ago · 6 comments

Description

This PR attempts to compile three tutorials:

  • neural_tangent_kernels
  • ensembling
  • per_sample_grads

To compile with fullgraph=True, you need a PyTorch build that includes the changes from https://github.com/pytorch/pytorch/pull/129091.
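For illustration, a minimal self-contained sketch of the pattern being applied (the function below is made up, not actual tutorial code; `backend="eager"` is a Dynamo debugging backend used here only so the snippet runs without a C++ toolchain, whereas the tutorials use the default inductor backend):

```python
import torch

# fullgraph=True makes Dynamo raise on the first graph break instead of
# silently falling back to eager for part of the function, so the whole
# function must be traceable end to end.
@torch.compile(fullgraph=True, backend="eager")
def fn(x, y):
    return torch.nn.functional.relu(x @ y)

x = torch.randn(4, 8)
y = torch.randn(8, 4)
out = fn(x, y)  # first call triggers compilation; later calls reuse the graph
```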

Performance gain

This is not a scientific benchmark: the model inputs are small, so the timings are susceptible to noise. Even so, we saw a clear gain on ensembling.

neural_tangent_kernels
[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  149.0
      eager           |   19.1
2 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  146.6
      eager           |   19.2
4 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  148.2
      eager           |   19.2
8 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  148.6
      eager           |   19.2
16 threads: ------------------
      @torch.compile  |   19.2
      make_fx         |  149.1
      eager           |   19.1

Times are in milliseconds (ms).

[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   17.6
      make_fx         |  145.1
      eager           |   17.9
2 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  142.7
      eager           |   18.0
4 threads: -------------------
      @torch.compile  |   17.7
      make_fx         |  142.3
      eager           |   18.1
8 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  144.6
      eager           |   17.8
16 threads: ------------------
      @torch.compile  |   17.7
      make_fx         |  144.8
      eager           |   18.0

Times are in milliseconds (ms).

[-- empirical_ntk_ntk_vps ---]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   62.3
      make_fx         |  123.0
      eager           |   62.9
2 threads: -------------------
      @torch.compile  |   62.4
      make_fx         |  123.6
      eager           |   63.1
4 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  122.9
      eager           |   63.1
8 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.2
16 threads: ------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.1

Times are in milliseconds (ms).
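For reference, the jacobian-contraction variant benchmarked above computes the empirical NTK by contracting per-sample parameter Jacobians. A rough, self-contained sketch using `torch.func` (the tiny linear model and sizes are illustrative, not the tutorial's):

```python
import torch
from torch.func import functional_call, vmap, jacrev

net = torch.nn.Linear(5, 3)
params = dict(net.named_parameters())

def fnet_single(params, x):
    # run the model on a single (unbatched) input
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

def empirical_ntk(params, x1, x2):
    # per-sample Jacobians w.r.t. parameters: dict of [N, O, *param_shape]
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac1 = [j.flatten(2) for j in jac1.values()]
    jac2 = [j.flatten(2) for j in jac2.values()]
    # contract over the flattened parameter dimension: [N1, N2, O, O]
    return sum(torch.einsum('Naf,Mbf->NMab', j1, j2)
               for j1, j2 in zip(jac1, jac2))

x1, x2 = torch.randn(7, 5), torch.randn(9, 5)
ntk = empirical_ntk(params, x1, x2)  # shape [7, 9, 3, 3]
```

A handy sanity check: for a single linear layer, the empirical NTK reduces to (x1·x2 + 1)·I, the +1 coming from the bias term.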
ensembling
[---- compute_predictions1 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    149.9
      make_fx         |  14963.3
      eager           |    303.4
2 threads: ---------------------
      @torch.compile  |    151.4
      make_fx         |  14664.8
      eager           |    326.5
4 threads: ---------------------
      @torch.compile  |    152.4
      make_fx         |  14680.3
      eager           |    327.3
8 threads: ---------------------
      @torch.compile  |    164.1
      make_fx         |  14694.9
      eager           |    332.6
16 threads: --------------------
      @torch.compile  |    151.9
      make_fx         |  14633.4
      eager           |    317.7

Times are in microseconds (us).

[---- compute_predictions2 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    147.0
      make_fx         |  14995.9
      eager           |    299.6
2 threads: ---------------------
      @torch.compile  |    149.0
      make_fx         |  14984.4
      eager           |    293.9
4 threads: ---------------------
      @torch.compile  |    147.1
      make_fx         |  15036.3
      eager           |    327.0
8 threads: ---------------------
      @torch.compile  |    150.4
      make_fx         |  14985.6
      eager           |    330.8
16 threads: --------------------
      @torch.compile  |    151.3
      make_fx         |  15123.6
      eager           |    301.2

Times are in microseconds (us).
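The ensembling functions benchmarked above vectorize a single forward pass over a stack of model parameters with `vmap`, and a function of this shape is what gets wrapped with `torch.compile`. A hedged sketch (the ensemble size and model here are made up):

```python
import copy
import torch
from torch.func import stack_module_state, functional_call

models = [torch.nn.Linear(4, 2) for _ in range(5)]
# stack the 5 models' parameters along a new leading dim
params, buffers = stack_module_state(models)

# a stateless "meta" copy serves as the function template to vmap over
base = copy.deepcopy(models[0]).to('meta')

def fmodel(params, buffers, x):
    return functional_call(base, (params, buffers), (x,))

x = torch.randn(8, 4)
# vmap over the stacked parameter dim; every model sees the same minibatch
preds = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
# preds has shape [5, 8, 2]: one batch of predictions per ensemble member
```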
per_sample_grads
[--- vmap_ft_compute_grad --]
                      |  cuda
1 threads: ------------------
      @torch.compile  |   6.5
      make_fx         |  48.1
      eager           |   7.1
2 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
4 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.4
8 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
16 threads: -----------------
      @torch.compile  |   6.0
      make_fx         |  48.1
      eager           |   6.5

Times are in milliseconds (ms).
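`vmap_ft_compute_grad` composes `grad` with `vmap` to compute per-sample gradients in one batched call, and the compiled variant simply wraps the transformed function. A small illustrative sketch (the model and loss are placeholders, not the tutorial's CNN):

```python
import torch
from torch.func import functional_call, vmap, grad

model = torch.nn.Linear(3, 2)
params = {k: v.detach() for k, v in model.named_parameters()}

def loss_fn(params, x, y):
    # loss for a single sample; vmap adds the batch dimension
    out = functional_call(model, params, (x.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out.squeeze(0), y)

# grad w.r.t. params, then vmap over the batch dim of (x, y)
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(
    params, torch.randn(16, 3), torch.randn(16, 2)
)
# per_sample_grads['weight'] has shape [16, 2, 3]: one gradient per sample
```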

cc @williamwen42 @msaroufim

guilhermeleobas avatar Jul 25 '24 18:07 guilhermeleobas

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2984

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure

As of commit 344861dbdb68d7d708d6604221217aa2932db70f with merge base f1c0b8a3675e6a09cd3628d981f047af24746ca2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Jul 25 '24 18:07 pytorch-bot[bot]

LGTM from a technical perspective, though is it fine to import profile_utils.py in a tutorial? @svekars

I can remove the import, or hide it behind an environment flag.

guilhermeleobas avatar Aug 01 '24 21:08 guilhermeleobas

How will this run in Google Colab?

svekars avatar Aug 02 '24 16:08 svekars

How will this run in Google Colab?

It doesn't have to. I've removed the calls to the profiling function; it was there only to verify that compiling the model actually yields a speedup.

guilhermeleobas avatar Aug 02 '24 17:08 guilhermeleobas

I see you removed profile_utils.py from the tutorial - do we still need to add it in this PR then?

williamwen42 avatar Aug 02 '24 18:08 williamwen42

I see you removed profile_utils.py from the tutorial - do we still need to add it in this PR then?

No, I've removed it as well.

guilhermeleobas avatar Aug 02 '24 20:08 guilhermeleobas

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions[bot] avatar Oct 12 '24 00:10 github-actions[bot]