
Add CUDA Graph and AOT Autograd support

Open xwang233 opened this issue 2 years ago • 7 comments

Add CUDA Graph support with --cuda-graph and AOT Autograd support with --aot-autograd to benchmark.py and train.py

The workflow for cuda graph in train.py might be a bit overcomplicated.

Related: https://github.com/rwightman/pytorch-image-models/issues/1244
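For reference, a `--cuda-graph` style option builds on PyTorch's public capture/replay API (`torch.cuda.CUDAGraph` / `torch.cuda.graph`, available since 1.10). The following is a minimal illustrative sketch of that pattern, not the PR's actual code; it is a no-op on machines without a GPU:

```python
import torch
import torch.nn.functional as F

def graphed_train_step_demo():
    # Sketch of the warmup -> capture -> replay pattern; not timm's code.
    if not torch.cuda.is_available():
        return None  # CUDA graphs require a GPU

    model = torch.nn.Linear(16, 16).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # The graph reads from fixed ("static") tensors, so inputs must be
    # copied into these buffers rather than passed fresh each step.
    static_input = torch.randn(8, 16, device="cuda")
    static_target = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so capture sees steady-state allocations.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            opt.zero_grad(set_to_none=True)
            F.mse_loss(model(static_input), static_target).backward()
            opt.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture one full forward + backward + optimizer step into a graph.
    g = torch.cuda.CUDAGraph()
    opt.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_loss = F.mse_loss(model(static_input), static_target)
        static_loss.backward()
        opt.step()

    # Replay: copy each new batch into the static tensors, then relaunch
    # the entire captured step with a single call.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    static_target.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()
    return static_loss.item()
```

The need for static input buffers and a separate warmup phase is also why wiring this into a general train loop adds the bookkeeping discussed below.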

xwang233 avatar May 24 '22 18:05 xwang233

I'm still working on extra benchmarks and accuracy tests for the new options at the moment.

xwang233 avatar May 24 '22 18:05 xwang233

cc @rwightman @csarofeen @ptrblck @kevinstephano @jjsjann123

xwang233 avatar May 24 '22 18:05 xwang233

FYI I intend to review (can't set myself as a reviewer)

csarofeen avatar May 24 '22 19:05 csarofeen

> FYI I intend to review (can't set myself as a reviewer)

Seems I can't add you as a formal reviewer either; it might require the reviewer to be added as a collaborator. Hmm, I thought only read-only access was needed...

rwightman avatar May 24 '22 20:05 rwightman

I have some ResNet50 AMP + channels-last training results with CUDA graph and nvfuser that verify training accuracy (loss, val acc) here: https://gist.github.com/xwang233/f3b5b4818762b08d716f969899b6d263.

After 10 epochs,

V100x8, BS = 128

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 4183.51 | 43.13 |
| CUDA graph | 4141.28 | 42.66 |
| CUDA graph + nvfuser | 4180.31 | 42.74 |

A100x8, BS = 32

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 5665.04 | 59.296 |
| CUDA graph | 6630.37 | 59.2 |
| CUDA graph + nvfuser | 6672.33 | 59.274 |

TL;DR: the training accuracies are the same for eager mode, CUDA graph, and CUDA graph + nvfuser. CUDA graph keeps training throughput the same at large batch sizes, but gives out-of-the-box improvements at small batch sizes. For example, in the results above, ResNet50 on A100x8 with batch size 32 improved training throughput from ~5600 to ~6600 img/s.

I'm also checking training accuracy with aot_autograd.
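For background, AOT Autograd traces the forward and backward graphs ahead of time and hands each one to a compiler backend (nvfuser, in this PR's case). A rough CPU-runnable illustration using functorch's `aot_function`, where a graph-printing "compiler" stands in for nvfuser (this is not the PR's code):

```python
import torch
from functorch.compile import aot_function

def inspect_compiler(fx_module, example_inputs):
    # fx_module is a torch.fx GraphModule holding the traced forward
    # (or backward) graph; print it, then return it unchanged.
    print(fx_module.code)
    return fx_module  # GraphModules are callable, so this is a valid compiler

def fn(x):
    return x.sin().sum()

# AOT Autograd extracts forward and backward graphs ahead of time and
# passes each to the supplied compiler.
aot_fn = aot_function(fn, fw_compiler=inspect_compiler, bw_compiler=inspect_compiler)

x = torch.randn(4, requires_grad=True)
aot_fn(x).backward()
# x.grad matches eager autograd: cos(x)
```

Because the backward graph is available ahead of time, the backend can fuse across the forward/backward boundary, which eager autograd cannot do.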

xwang233 avatar May 25 '22 17:05 xwang233

ResNet50 FP32 training results with eager, CUDA graph, TorchScript + nvfuser, and AOT Autograd + nvfuser: https://gist.github.com/xwang233/d5136facb3361af54693081da346fd33

After 10 epochs,

A100x8, BS = 128

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 6499.87 | 38.21 |
| CUDA graph | 6608.18 | 43.19 |
| TorchScript + nvfuser | 6453.39 | 38.77 |
| AOT Autograd + nvfuser | 6887.96 | 38.33 |

A100x8, BS = 32

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 5130.81 | 59.50 |
| CUDA graph | 5833.69 | 57.36 |
| TorchScript + nvfuser | 4986.46 | 59.50 |
| AOT Autograd + nvfuser | 5228.49 | 59.39 |

V100x8, BS = 64

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 2573.18 | 51.29 |
| CUDA graph | 2653.16 | 53.00 |
| TorchScript + nvfuser | 2581.52 | 51.19 |
| AOT Autograd + nvfuser | 2687.08 | 51.24 |

xwang233 avatar May 26 '22 20:05 xwang233

@csarofeen @kevinstephano @xwang233 Putting a few comments down here that relate to the whole PR. One of the reasons I haven't put time into exploring graph replay in the train script until now is that it was clear it would take a fair bit of very specific code that makes quite a mess of the train loop and setup code...

It's nice to see it together, but I'm not sure it's worth it just yet; it really needs to be pushed into a model / task wrapper. I had a plan to work it into the bits_and_tpu branch (https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits) that I've been using for PT XLA TPU training (to be merged to master some day). There are GPU (CUDA) and XLA-specific interfaces for device, distributed primitives, and optimizer / step updates... I need to further refine it to cover DeepSpeed though. I feel graph mode would make sense as a state machine within the class wrapping the optimizer / step (Updater).

I have to think if there's a way to have the graph code in this train script that better separates the extra code (even if it adds redundancy)...
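For what it's worth, a state machine along those lines could be as small as the following sketch. This is entirely hypothetical: the class and state names are assumptions, not code from the bits branch, and the actual CUDA-graph calls are stubbed out as comments so the control flow stays visible:

```python
from enum import Enum, auto

class GraphState(Enum):
    WARMUP = auto()   # run eagerly until allocations reach steady state
    CAPTURE = auto()  # record one step into a CUDA graph
    REPLAY = auto()   # launch the captured graph from here on

class Updater:
    """Hypothetical step wrapper: the train loop just calls step() every
    iteration, and the wrapper decides whether to warm up, capture, or
    replay, keeping graph bookkeeping out of the loop itself."""

    def __init__(self, step_fn, warmup_steps=3):
        self.step_fn = step_fn        # eager forward + backward + optimizer step
        self.warmup_steps = warmup_steps
        self.state = GraphState.WARMUP
        self._count = 0

    def step(self, batch):
        if self.state is GraphState.WARMUP:
            self.step_fn(batch)
            self._count += 1
            if self._count >= self.warmup_steps:
                self.state = GraphState.CAPTURE
        elif self.state is GraphState.CAPTURE:
            # Real code would run step_fn under torch.cuda.graph(...) here.
            self.step_fn(batch)
            self.state = GraphState.REPLAY
        else:
            # Real code would copy batch into static tensors and g.replay().
            self.step_fn(batch)
```

The train loop then stays identical for eager and graphed runs; only the Updater's internals change.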

rwightman avatar May 26 '22 21:05 rwightman