`@torch.compile` some tutorials
Description
This PR attempts to compile three tutorials:
- neural_tangent_kernels
- ensembling
- per_sample_grads
To compile with `fullgraph=True`, one needs a PyTorch build that includes the changes from https://github.com/pytorch/pytorch/pull/129091.
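As a minimal sketch of the pattern being applied (the function and shapes below are made-up placeholders, not code from the tutorials), `fullgraph=True` tells `torch.compile` to capture the whole function as one graph and raise an error on any graph break instead of silently falling back to eager:

```python
import torch

def pointwise_fn(x, y):
    # Toy stand-in for a tutorial model (hypothetical example).
    return torch.sin(x) * y + x

# fullgraph=True: error out on graph breaks rather than splitting the graph.
# backend="eager" skips code generation so this sketch runs without a C++
# toolchain; the tutorials use the default inductor backend.
compiled_fn = torch.compile(pointwise_fn, fullgraph=True, backend="eager")

x = torch.randn(8)
y = torch.randn(8)
out = compiled_fn(x, y)
# The compiled function is numerically equivalent to the eager one.
assert torch.allclose(out, pointwise_fn(x, y))
```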
Performance gain
This is not a scientific benchmark: the inputs to the models are small, so the measurements can suffer from noise. That said, we did see some clear gains on ensembling.
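The tables below are `torch.utils.benchmark` comparisons. A sketch of how such a table is produced (the workload, labels, and thread counts are placeholders, not the tutorial code):

```python
import torch
import torch.utils.benchmark as benchmark

def workload(x):
    # Placeholder for the tutorial function being timed.
    return torch.sin(x).sum()

# backend="eager" keeps the sketch runnable without a compiler toolchain.
compiled = torch.compile(workload, backend="eager")

x = torch.randn(1024)
results = []
for num_threads in [1, 2]:
    for name, fn in [("@torch.compile", compiled), ("eager", workload)]:
        timer = benchmark.Timer(
            stmt="fn(x)",
            globals={"fn": fn, "x": x},
            num_threads=num_threads,
            label="workload",     # table title, e.g. compute_predictions1
            sub_label=name,       # row label
            description="cpu",    # column label, e.g. cuda
        )
        results.append(timer.blocked_autorange(min_run_time=0.05))

# Compare groups the measurements into the tables shown below.
benchmark.Compare(results).print()
```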
neural_tangent_kernels

```
[ empirical_ntk_jacobian_contraction ]
                      |  cuda
1 threads: ---------------------------
      @torch.compile  |   18.9
      make_fx         |  149.0
      eager           |   19.1
2 threads: ---------------------------
      @torch.compile  |   19.1
      make_fx         |  146.6
      eager           |   19.2
4 threads: ---------------------------
      @torch.compile  |   18.9
      make_fx         |  148.2
      eager           |   19.2
8 threads: ---------------------------
      @torch.compile  |   19.1
      make_fx         |  148.6
      eager           |   19.2
16 threads: --------------------------
      @torch.compile  |   19.2
      make_fx         |  149.1
      eager           |   19.1

Times are in milliseconds (ms).
```
```
[ empirical_ntk_jacobian_contraction ]
                      |  cuda
1 threads: ---------------------------
      @torch.compile  |   17.6
      make_fx         |  145.1
      eager           |   17.9
2 threads: ---------------------------
      @torch.compile  |   17.8
      make_fx         |  142.7
      eager           |   18.0
4 threads: ---------------------------
      @torch.compile  |   17.7
      make_fx         |  142.3
      eager           |   18.1
8 threads: ---------------------------
      @torch.compile  |   17.8
      make_fx         |  144.6
      eager           |   17.8
16 threads: --------------------------
      @torch.compile  |   17.7
      make_fx         |  144.8
      eager           |   18.0

Times are in milliseconds (ms).
```
```
[-- empirical_ntk_ntk_vps ---]
                      |  cuda
1 threads: -------------------
      @torch.compile  |   62.3
      make_fx         |  123.0
      eager           |   62.9
2 threads: -------------------
      @torch.compile  |   62.4
      make_fx         |  123.6
      eager           |   63.1
4 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  122.9
      eager           |   63.1
8 threads: -------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.2
16 threads: ------------------
      @torch.compile  |   62.5
      make_fx         |  123.8
      eager           |   63.1

Times are in milliseconds (ms).
```
ensembling

```
[---- compute_predictions1 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    149.9
      make_fx         |  14963.3
      eager           |    303.4
2 threads: ---------------------
      @torch.compile  |    151.4
      make_fx         |  14664.8
      eager           |    326.5
4 threads: ---------------------
      @torch.compile  |    152.4
      make_fx         |  14680.3
      eager           |    327.3
8 threads: ---------------------
      @torch.compile  |    164.1
      make_fx         |  14694.9
      eager           |    332.6
16 threads: --------------------
      @torch.compile  |    151.9
      make_fx         |  14633.4
      eager           |    317.7

Times are in microseconds (us).
```
```
[---- compute_predictions2 ----]
                      |    cuda
1 threads: ---------------------
      @torch.compile  |    147.0
      make_fx         |  14995.9
      eager           |    299.6
2 threads: ---------------------
      @torch.compile  |    149.0
      make_fx         |  14984.4
      eager           |    293.9
4 threads: ---------------------
      @torch.compile  |    147.1
      make_fx         |  15036.3
      eager           |    327.0
8 threads: ---------------------
      @torch.compile  |    150.4
      make_fx         |  14985.6
      eager           |    330.8
16 threads: --------------------
      @torch.compile  |    151.3
      make_fx         |  15123.6
      eager           |    301.2

Times are in microseconds (us).
```
per_sample_grads

```
[--- vmap_ft_compute_grad --]
                      |  cuda
1 threads: ------------------
      @torch.compile  |   6.5
      make_fx         |  48.1
      eager           |   7.1
2 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
4 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.4
8 threads: ------------------
      @torch.compile  |   6.0
      make_fx         |  47.1
      eager           |   6.8
16 threads: -----------------
      @torch.compile  |   6.0
      make_fx         |  48.1
      eager           |   6.5

Times are in milliseconds (ms).
```
cc @williamwen42 @msaroufim