`@torch.compile` some tutorials

Open guilhermeleobas opened this issue 7 months ago • 6 comments

Description

This PR applies `@torch.compile` to three tutorials:

  • neural_tangent_kernels
  • ensembling
  • per_sample_grads

Compiling with `fullgraph=True` requires a PyTorch build that includes the changes from https://github.com/pytorch/pytorch/pull/129091.

Performance gain

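This is not a rigorous benchmark: the model inputs are small, so the measurements are susceptible to noise. Even so, ensembling shows a clear gain, with `@torch.compile` running roughly twice as fast as eager (about 150 us versus 300 to 330 us).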
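
The tables in this description have the layout printed by `torch.utils.benchmark.Compare`. Assuming that is how they were produced (the PR text does not say), a harness along these lines would generate a comparable table. The toy function `f`, the labels, and the thread counts are placeholders, not the code used for the numbers here.

```python
import torch
import torch.utils.benchmark as benchmark


def f(x):
    # Placeholder workload; the PR times the tutorial functions instead
    return torch.sin(x).sum()


x = torch.randn(64, 64)

# One entry per table row. A compiled row could be added with, for example,
# variants["@torch.compile"] = torch.compile(f)
variants = {"eager": f}

results = []
for num_threads in (1, 2):
    for name, fn in variants.items():
        fn(x)  # warm-up call so one-time setup cost is not timed
        timer = benchmark.Timer(
            stmt="fn(x)",
            globals={"fn": fn, "x": x},
            num_threads=num_threads,
            label="toy_function",   # table title
            sub_label=name,         # row name
            description="cpu",      # column name
        )
        # Run repeatedly until enough time has elapsed for a stable estimate
        results.append(timer.blocked_autorange(min_run_time=0.05))

compare = benchmark.Compare(results)
compare.print()
```

Raising `min_run_time` and using larger inputs are two straightforward ways to reduce the noise mentioned above.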

neural_tangent_kernels
[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  149.0
      eager           |   19.1
2 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  146.6
      eager           |   19.2
4 threads: -------------------
      @torch.compile  |   18.9
      make_fx         |  148.2
      eager           |   19.2
8 threads: -------------------
      @torch.compile  |   19.1
      make_fx         |  148.6
      eager           |   19.2
16 threads: ------------------
      @torch.compile  |   19.2
      make_fx         |  149.1
      eager           |   19.1

Times are in milliseconds (ms).

[ empirical_ntk_jacobian_contraction ]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   17.6
      make_fx         |  145.1
      eager           |   17.9
2 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  142.7
      eager           |   18.0
4 threads: -------------------
      @torch.compile  |   17.7
      make_fx         |  142.3
      eager           |   18.1
8 threads: -------------------
      @torch.compile  |   17.8
      make_fx         |  144.6
      eager           |   17.8
16 threads: ------------------
      @torch.compile  |   17.7
      make_fx         |  144.8
      eager           |   18.0

Times are in milliseconds (ms).

[-- empirical_ntk_ntk_vps ---]
                      |   cuda
1 threads: -------------------
      @torch.compile  |   62.3
      make_fx         |  123.0
      eager           |   62.9
2 threads: -------------------
      @torch.compile  |   62.4
      make_fx         |  123.6
      eager           |   63.1
4 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  122.9
      eager           |   63.1
8 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.2
16 threads: ------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.1

Times are in milliseconds (ms).

ensembling
[---- compute_predictions1 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    149.9
      make_fx         |  14963.3
      eager           |    303.4
2 threads: ---------------------
      @torch.compile  |    151.4
      make_fx         |  14664.8
      eager           |    326.5
4 threads: ---------------------
      @torch.compile  |    152.4
      make_fx         |  14680.3
      eager           |    327.3
8 threads: ---------------------
      @torch.compile  |    164.1
      make_fx         |  14694.9
      eager           |    332.6
16 threads: --------------------
      @torch.compile  |    151.9
      make_fx         |  14633.4
      eager           |    317.7

Times are in microseconds (us).

[---- compute_predictions2 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    147.0
      make_fx         |  14995.9
      eager           |    299.6
2 threads: ---------------------
      @torch.compile  |    149.0
      make_fx         |  14984.4
      eager           |    293.9
4 threads: ---------------------
      @torch.compile  |    147.1
      make_fx         |  15036.3
      eager           |    327.0
8 threads: ---------------------
      @torch.compile  |    150.4
      make_fx         |  14985.6
      eager           |    330.8
16 threads: --------------------
      @torch.compile  |    151.3
      make_fx         |  15123.6
      eager           |    301.2

Times are in microseconds (us).
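
To put a number on the ensembling gain, the times in the compute_predictions1 table can be averaged across thread counts and compared. The helper `speedup_over` is a hypothetical name introduced here, and averaging over thread counts is a simplification chosen for this sketch.

```python
from statistics import mean

# Times in microseconds, copied from the compute_predictions1 table
# (one value per thread count: 1, 2, 4, 8, 16)
compute_predictions1 = {
    "@torch.compile": [149.9, 151.4, 152.4, 164.1, 151.9],
    "make_fx": [14963.3, 14664.8, 14680.3, 14694.9, 14633.4],
    "eager": [303.4, 326.5, 327.3, 332.6, 317.7],
}


def speedup_over(baseline, candidate, table):
    """Mean baseline time divided by mean candidate time (>1: candidate is faster)."""
    return mean(table[baseline]) / mean(table[candidate])


compile_vs_eager = speedup_over("eager", "@torch.compile", compute_predictions1)
make_fx_vs_eager = speedup_over("eager", "make_fx", compute_predictions1)

print(f"@torch.compile vs eager: {compile_vs_eager:.2f}x")
print(f"make_fx vs eager: {make_fx_vs_eager:.3f}x")
```

A ratio above 1 means the candidate is faster than eager; a ratio below 1 means it is slower.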

per_sample_grads
[--- vmap_ft_compute_grad --]
                      |  cuda
1 threads: ------------------
      @torch.compile  |   6.5
      make_fx         |  48.1
      eager           |   7.1
2 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
4 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.4
8 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
16 threads: -----------------
      @torch.compile  |   6.0
      make_fx         |  48.1
      eager           |   6.5

Times are in milliseconds (ms).

cc @williamwen42 @msaroufim

guilhermeleobas, Jul 25 '24 18:07