[Do not land] [RFC] 1.375x speedup - Remove control flow from model, small hacks, enable TorchDynamo + TorchInductor

voznesenskym opened this issue

Obviously not meant to land; this PR is representative of what it would take to get Dynamo working.

test_me.py takes 4.4 seconds on the main branch and 3.2 seconds in this PR.

Overview:

  1. I took a free audiobook recording of chapter 1 of Charles Dickens' David Copperfield
  2. I used an MP3 splitting tool to split it into 8 parts, and then used the utility in the model code to get 10 chunks of 30 seconds each
  3. I "preheated" the model with audio part 0, and then ran inference on the other 9 parts
  4. With no model changes, it took 4.4 seconds
  5. After applying Inductor to the model's forward method, it segfaulted.
  6. Back to the drawing board: after applying Inductor to selected model parts, it took 289 seconds
  7. After rewriting the bits you see below, then applying inductor, I got it down to 3.2 seconds
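The measurement in the steps above (one untimed warm-up run, then timing the remaining chunks) can be sketched in Python as below. This is a minimal sketch under assumptions, not the contents of test_me.py: `time_preheated` and `speedup` are hypothetical helpers, and `run` stands in for whatever transcribes one audio chunk.

```python
import time
from typing import Callable, Iterable


def time_preheated(run: Callable[[object], object], warmup_sample: object,
                   samples: Iterable[object]) -> float:
    """Return wall-clock seconds spent calling `run` on each of `samples`.

    `run` is first called once on `warmup_sample` and that call is not
    timed, so one-off costs (compilation, cache population) are excluded.
    """
    run(warmup_sample)  # "preheat": untimed warm-up call
    start = time.perf_counter()
    for sample in samples:
        run(sample)
    return time.perf_counter() - start


def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    """Ratio of baseline time to optimized time (> 1 means faster)."""
    return baseline_seconds / optimized_seconds
```

For example, `speedup(4.4, 3.2)` returns 1.375, the figure reported in this RFC. When timing on a GPU you would also need to call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously; that is left out here to keep the sketch dependency-free.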

Primary changes:

  1. Remove the cache to avoid control flow
  2. Remove control flow around cross entropy
  3. Duplicate the multi-head attention class
  4. Apply Inductor selectively, based on profiling + experimentation

Result:

On pre-heated inference (the model first runs once on a warm-up sample that is excluded from the timed runs), TorchDynamo makes Whisper's inference 1.375x faster (4.4 s down to 3.2 s).

Followups:

There are probably a lot more wins we can get here, maybe even to 2x with careful improvements to TorchDynamo + inductor.

One of the major "fails" here is that we cannot just wrap the model's main entry point with TorchDynamo. We have too many gaps at the moment.

voznesenskym, Sep 22, 2022

For those who haven't heard of TorchDynamo/TorchInductor: together they automatically fuse PyTorch operations and map them to Triton.

jansel, Sep 23, 2022

This sounds great! I was also wondering how fast it'd be if Triton's flash attention was integrated, but unfortunately it's A100 only.

Implementation-wise, I think we could subclass PyTorchInference(Inference) and monkey-patch the attention layers only when TorchDynamo is available, so that the code remains usable with older PyTorch versions.
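A minimal Python sketch of the "patch only when TorchDynamo is available" idea follows, under loud assumptions: `dynamo_available` and `maybe_patch_attention` are hypothetical helpers, and modelling the layers as objects with an `attn` attribute is an illustration, not Whisper's actual class layout. `wrap` stands in for whatever would produce the optimized attention.

```python
import importlib.util
from typing import Callable, Iterable, Optional


def dynamo_available() -> bool:
    """Best-effort check for an importable TorchDynamo."""
    # TorchDynamo shipped as the standalone `torchdynamo` package in 2022
    # and was later moved into PyTorch itself as `torch._dynamo`.
    for name in ("torch._dynamo", "torchdynamo"):
        try:
            if importlib.util.find_spec(name) is not None:
                return True
        except (ImportError, ValueError):
            # ImportError (incl. ModuleNotFoundError): parent package missing.
            continue
    return False


def maybe_patch_attention(blocks: Iterable[object],
                          wrap: Callable[[Callable], Callable],
                          available: Optional[bool] = None) -> bool:
    """Replace each block's `attn` with `wrap(attn)` only if Dynamo is usable.

    Returns True if the blocks were patched. Otherwise the blocks are left
    untouched, so the original code path keeps working on older PyTorch.
    """
    if available is None:
        available = dynamo_available()
    if not available:
        return False
    for block in blocks:
        block.attn = wrap(block.attn)  # hypothetical attribute name
    return True
```

A PyTorchInference subclass could call `maybe_patch_attention` from its constructor; when TorchDynamo is missing the call is a no-op, so older PyTorch versions keep the original code path.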

jongwook, Sep 23, 2022

That sounds awesome. I'd love to do that. I'm in the PyTorch Slack and the Triton Slack; would you like to chat there? I also have further questions about getting a bit of realistic inference data so we can set up a benchmark on our end (as well as better measure accuracy). The ad hoc free-audiobook approach isn't scaling super well, hah.

voznesenskym, Sep 23, 2022

Do you have more benchmarks, for example on CPU?

taylorchu, Sep 23, 2022

No, I am sorry, I do not. I plan on working with @jongwook to benchmark this properly :)

voznesenskym, Sep 25, 2022

This is the RFC; the implementation PR will be here: https://github.com/openai/whisper/pull/115

voznesenskym, Sep 25, 2022

@voznesenskym I am curious why you guys started with TorchDynamo instead of the more widely adopted TorchScript. We are in the process of making this torch.jit-compatible, so I was wondering whether TorchScript is slower in comparison.

taylorchu, Sep 25, 2022

TorchDynamo and TorchScript are fundamentally different projects, and we are investing in TorchDynamo as a next-gen core component of our stack. While their surface-level goals (in this case, speed) align, they are rather different. I am happy to go into it more, but the README in the TorchDynamo project goes into great depth about what the project is. Have you had a chance to read that yet?

voznesenskym, Sep 25, 2022

@taylorchu I definitely recommend going the TorchDynamo route rather than torch.jit. It's more aligned with our future plans.

soumith, Sep 25, 2022

@soumith @voznesenskym Is the PyTorch team's plan for TorchDynamo vs. torch.jit written down somewhere?

I am interested in whether one should choose one over the other in certain use cases.

taylorchu, Sep 26, 2022

Not written down anywhere concretely; we'll talk about it in a few months. But about Dynamo itself, we have quite a few posts with various updates here: https://dev-discuss.pytorch.org/

soumith, Sep 26, 2022

Just in case: I can provide a large set of data transcribed by Whisper so that you guys can validate whether the change affects the text output.

nlgtuankiet, Oct 2, 2022

@voznesenskym I am trying to benchmark your approach with TorchDynamo but ran into some module errors. Do you know which versions of TorchInductor, TorchDynamo, and Triton were used to make your modification work?

Shiro-LK, Oct 6, 2022

Hey, Dynamo migrated to the latest Triton, so we may have some new errors here, but the TorchDynamo Makefile https://github.com/pytorch/torchdynamo/blob/main/Makefile lists the versions of all our dependencies (usually cutting-edge nightlies).

I plan to revisit this shortly, and will fix up any errors I find.

voznesenskym, Oct 13, 2022

Thanks, I'll close this for now, since it doesn't quite yet work "out of the box" and relying on nightly versions makes things difficult for me to maintain. I'm hoping to get an easier integration with the stable PyTorch 2 interface once it's out.

jongwook, Dec 4, 2022