TensorRT

Performance of structured sparsity for InceptionV3 on A6000 GPU

shivmgg opened this issue 3 years ago • 5 comments

Hi,

I trained and ran both a dense and a 2:4 structured sparse version of InceptionV3 on an NVIDIA A6000 GPU. However, I see only a marginal performance improvement, and at the largest batch size the sparse model is actually slower.

Here are the latency numbers (in ms) for various batch sizes.

| Batch size | 32 | 64 | 128 | 256 |
|---|---|---|---|---|
| Dense | 10.4 | 19.83 | 37.94 | 74.73 |
| 2:4 structured sparse | 9.97 | 19.08 | 36.69 | 91.86 |
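To put these numbers in perspective, the short sketch below computes the dense-over-sparse latency ratio per batch size from the table above. It is only arithmetic on the reported values; a ratio above 1.0 means the sparse engine is faster, below 1.0 means it is slower.

```python
# Latencies (ms) copied from the table above.
batch_sizes = [32, 64, 128, 256]
dense_ms = [10.4, 19.83, 37.94, 74.73]
sparse_ms = [9.97, 19.08, 36.69, 91.86]


def speedups(baseline, candidate):
    """Return baseline/candidate ratios; >1.0 means the candidate is faster."""
    return [b / c for b, c in zip(baseline, candidate)]


for bs, s in zip(batch_sizes, speedups(dense_ms, sparse_ms)):
    print(f"batch {bs:>3}: {s:.3f}x")
```

This works out to roughly a 3 to 4% gain up to batch 128, and about a 19% slowdown at batch 256.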

I used the following model files (batch size 32):

Dense model: https://drive.google.com/file/d/1UtARXRPJZNs4_6Hsv9BcUmxP6V6GAaPn/view?usp=sharing
Sparse model: https://drive.google.com/file/d/1Czbgt7yV_LfBmtlU7XvtF1by1Js0upwZ/view?usp=sharing

I used the following command to run the model with TensorRT:

trtexec --onnx=inception_sparse_model.onnx --fp16 --saveEngine=inception_pytorch_sparse.trt --sparsity=enable --explicitBatch --workspace=40000

Please let me know if the performance can be improved.

shivmgg, Sep 08 '22

In my experience, enabling sparsity usually gives at most about a 15% (often less) end-to-end speedup for most CNN-based models. Sometimes the speedup is barely noticeable because our dense kernels are already well optimized.

@shivmgg Do your dense and sparse models have exactly the same structure? Can you try running your dense model with --sparsity=force and compare it with the normal run?
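With --sparsity=force, TensorRT overwrites the weights so they follow the 2:4 pattern (at most 2 non-zeros in every group of 4), which lets a dense model be benchmarked with sparse kernels without retraining. The sketch below shows one common way to produce that pattern: keep the two largest-magnitude values in each group of four. This magnitude rule is an assumption for illustration; the thread does not say how TensorRT selects which weights to zero, and the helper names are hypothetical.

```python
from typing import List


def prune_2_4(row: List[float]) -> List[float]:
    """Zero the two smallest-magnitude values in every group of 4.

    Assumes len(row) is a multiple of 4 (e.g. one filter's weights laid
    out along the input-channel dimension).
    """
    if len(row) % 4:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # The two indices with the smallest |w| in this group are zeroed.
        drop = sorted(group, key=lambda i: abs(row[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out


def is_2_4_sparse(row: List[float]) -> bool:
    """True if every group of 4 has at most 2 non-zero values."""
    return all(sum(1 for v in row[s:s + 4] if v != 0) <= 2
               for s in range(0, len(row), 4))


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
pruned = prune_2_4(weights)
print(pruned, is_2_4_sparse(pruned))
```

Because forcing the pattern ignores accuracy, it is useful for perf comparisons only, which is exactly how it is being used here.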

@nvpohanh any comments here?

zerollzeng, Sep 09 '22

Here are the latency numbers (in ms) with --sparsity=force for the dense model.

| Batch size | 32 | 64 | 128 | 256 |
|---|---|---|---|---|
| Dense | 10.4 | 19.83 | 37.94 | 74.73 |
| Dense (--sparsity=force) | 9.99 | 19.09 | 36.84 | 71.55 |
| 2:4 structured sparse | 9.97 | 19.08 | 36.69 | 91.86 |
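Latency per batch can also be read as throughput (images per second = batch size × 1000 / latency in ms). The sketch below converts the three rows of the table above; it adds no new measurements.

```python
# Latencies (ms) from the table above.
batch_sizes = [32, 64, 128, 256]
latency_ms = {
    "dense": [10.4, 19.83, 37.94, 74.73],
    "dense_forced": [9.99, 19.09, 36.84, 71.55],
    "sparse_2_4": [9.97, 19.08, 36.69, 91.86],
}


def throughput(batch: int, ms: float) -> float:
    """Images per second for one batch that takes `ms` milliseconds."""
    return batch * 1000.0 / ms


for name, values in latency_ms.items():
    ips = [throughput(b, m) for b, m in zip(batch_sizes, values)]
    print(name, ", ".join(f"{x:,.0f}" for x in ips))
```

Up to batch 128, the forced-sparsity dense engine and the trained 2:4 model are within about 0.5% of each other; only the trained sparse model regresses at batch 256.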

The two models have exactly the same structure. I didn't get even a 10% performance improvement. Please let me know what can be done to improve performance.

If the dense kernels are well-optimized, what is the best possible way to compare the model performance of sparse models with dense models? Can I disable any optimizations for the dense model that are not enabled for the sparse counterpart?

shivmgg, Sep 09 '22

I think what you see is probably expected. I've run perf tests on several public models, and most of them get less than a 10% improvement (small models in particular get almost no improvement). Some things worth mentioning here:

  1. Only some conv layers can run with sparse kernels.
  2. Sparsity only gives a good perf improvement when the conv layer is big, e.g., when the channel count is larger than 256.
  3. Even with --sparsity=force, TRT won't choose a sparse tactic if a dense tactic is faster. You can see this in TRT's verbose log.
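Point 3 can be pictured as a timing-based choice: for each layer the builder measures candidate kernels and keeps the fastest. The sketch below models that idea in plain Python with made-up tactic names and timings; it is a simplified illustration, not TensorRT's actual builder logic or API.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Tactic:
    name: str       # hypothetical label, not a real TRT tactic name
    sparse: bool    # True if the kernel uses sparse Tensor Cores
    time_ms: float  # measured time for this layer


def pick_tactic(candidates: List[Tactic]) -> Tactic:
    """Choose the fastest candidate; being sparse gives no extra credit."""
    return min(candidates, key=lambda t: t.time_ms)


# Hypothetical timings for one conv layer: the dense kernel wins,
# so this layer stays dense even though a sparse tactic exists.
layer = [Tactic("dense_kernel_a", False, 0.120),
         Tactic("sparse_kernel_b", True, 0.135)]
print(pick_tactic(layer).name)
```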

> If the dense kernels are well-optimized, what is the best possible way to compare the model performance of sparse models with dense models? Can I disable any optimizations for the dense model that are not enabled for the sparse counterpart?

Unfortunately, you can't.

@nvpohanh I think we could add a section about sparsity performance to our developer guide.

zerollzeng, Sep 09 '22

Thanks a lot for the info!

May I know what kinds of optimizations the dense kernels have that make them better than their sparse counterparts? Theoretically, the sparse implementation should be about 2x faster, since it performs half as many operations. Clearly, something is missing.
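One reason the end-to-end number falls well short of 2x can be seen with Amdahl's law: if only a fraction f of the runtime is spent in layers that actually run sparse kernels, and those layers get a speedup s, the overall speedup is 1 / ((1 - f) + f / s). The values of f below are assumed scenarios, not measurements of InceptionV3.

```python
def overall_speedup(f: float, s: float) -> float:
    """Amdahl's law: fraction f of runtime accelerated by factor s."""
    if not 0.0 <= f <= 1.0:
        raise ValueError("f must be in [0, 1]")
    if s <= 0:
        raise ValueError("s must be positive")
    return 1.0 / ((1.0 - f) + f / s)


# Assumed scenarios: even a perfect 2x on the eligible layers
# yields much less end to end when f is small.
for f in (0.2, 0.5, 0.8, 1.0):
    print(f"f={f:.1f}: {overall_speedup(f, 2.0):.2f}x")
```

In practice the per-layer speedup s is also usually below 2x, which shrinks the end-to-end gain further.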

If there is only a minimal advantage to using sparse Tensor Cores on large GPUs such as the A6000 and A100, there would be no point in exploring sparse neural networks on recently released edge GPUs such as the Jetson Orin modules, right?

shivmgg, Sep 12 '22

> May I know what kinds of optimizations are present in the dense kernels that make them better than their sparse counterpart?

Lots of them :-)

> If there is very minimal advantage of using sparse tensor cores on large GPUs such as A6000 and A100, there would not be any point exploring sparse neural networks on recently released edge GPUs such as Jetson Orin modules, right?

I think the answer is yes, but I cannot fully confirm it. @nvpohanh may have a clear answer.

zerollzeng, Sep 13 '22

In TRT 8.5 GA, we added a few optimized sparse INT8 3x3 conv kernels, so if INT8 is an option for you, give it a try.

Generally speaking, InceptionV3 is too small and does not utilize the GPU enough, so sparsity doesn't give much benefit. Based on my experience, sparsity only helps if both the input and output channel counts are 256 or greater. Examples that show good sparsity perf are VGG16/19 and BERT, which have much heavier convolution and matrix-multiplication workloads.
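The 256-channel rule of thumb above can be applied as a quick screen over a model's conv layers. The sketch below uses a hypothetical layer list and helper names; it only flags layers where sparsity is likely worth trying, since TensorRT still decides per layer by timing.

```python
from typing import List, NamedTuple


class ConvLayer(NamedTuple):
    name: str          # hypothetical layer name
    in_channels: int
    out_channels: int


def sparsity_candidates(layers: List[ConvLayer],
                        min_channels: int = 256) -> List[str]:
    """Names of conv layers where both channel counts meet the threshold."""
    return [layer.name for layer in layers
            if layer.in_channels >= min_channels
            and layer.out_channels >= min_channels]


# Hypothetical layers, not taken from InceptionV3.
layers = [ConvLayer("conv_a", 64, 96),
          ConvLayer("conv_b", 256, 256),
          ConvLayer("conv_c", 384, 192)]
print(sparsity_candidates(layers))
```

If only a small share of the runtime comes from layers that pass this screen, a small end-to-end gain is the expected outcome.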

nvpohanh, Dec 02 '22

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions, thanks!

ttyio, Dec 27 '22

@shivmgg, would it be possible to reshare your models? If your structured sparsification process only zeroes out columns, I would expect you still won't get a big speedup, because those operations are still computed. For the same model, DeepSparse (a neural engine) showed a 3.6x speedup, compared to 1.1x when inferring on the GPU. However, for a different architecture containing CNN layers where I was able to filter them out, I got a decent speedup (a 99% sparse CNN showed a 6x speedup).
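The difference between zeroing weights and physically removing them can be shown by counting multiply-accumulates (MACs) for a conv layer. The sketch below assumes the zeroed units are whole output channels (filters) and uses a hypothetical layer shape; a dense kernel still performs every MAC even when many weights are zero, while dropping the all-zero filters actually shrinks the layer.

```python
def conv_macs(in_ch: int, out_ch: int, k: int, out_h: int, out_w: int) -> int:
    """MACs for a dense k x k convolution (no groups, no bias)."""
    return out_ch * in_ch * k * k * out_h * out_w


# Hypothetical layer: 256 -> 512 channels, 3x3 kernel, 14x14 output.
# Assume pruning left only the first 128 filters with non-zero weights.
filter_is_nonzero = [i < 128 for i in range(512)]
kept = sum(filter_is_nonzero)

# Zeroed in place: the dense kernel still computes all 512 filters.
zeroed = conv_macs(256, 512, 3, 14, 14)
# Removed: the layer really has only `kept` filters now.
removed = conv_macs(256, kept, 3, 14, 14)
print(kept, zeroed / removed)
```

Removing output channels also reduces the input channels of the next layer, so the real saving compounds across layers; this sketch counts only one layer.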

Heatdh, Nov 17 '23