Flax vs torch benchmark on Wav2vec2

Open ZurabDz opened this issue 3 years ago • 6 comments

So my question is, should FlaxWav2Vec2ForCTC generally be faster than Wav2Vec2ForCTC?

1.14 s ± 138 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> FlaxWav2Vec2ForCTC

37.7 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> Wav2Vec2ForCTC

So the question is: should Flax be faster than the default PyTorch model?

P.S. The benchmarks were run on GPU. VRAM usage also seems to be drastically larger with Flax for some reason.
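One possible explanation for the higher VRAM figure, offered as an assumption since the thread never confirms it: by default JAX preallocates a large fraction of GPU memory when its backend starts, so tools like nvidia-smi report high usage regardless of model size. A minimal sketch of switching that off through the documented `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable (the array below is only a placeholder workload):

```python
import os

# Set these before JAX initialises its backend; doing it before
# `import jax` is the safe choice.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# Alternative: keep preallocation but cap it (the value is illustrative).
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax
import jax.numpy as jnp

# Placeholder workload so that the backend actually starts up.
x = jnp.ones((1024, 1024))
total = x.sum()
print(jax.default_backend(), total)
```

With preallocation disabled, the memory reported for the process should track what the model actually needs, which makes a comparison with PyTorch more meaningful.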

ZurabDz avatar Sep 19 '22 14:09 ZurabDz

maybe of interest to @sanchit-gandhi

LysandreJik avatar Sep 19 '22 22:09 LysandreJik

Hey @ZurabDz! FlaxWav2Vec2ForCTC should be faster than Wav2Vec2ForCTC if the __call__ method is just-in-time (JIT) compiled (cf. https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). Could you share your code for running this benchmark? We can then go through and make sure the Flax model is appropriately set up to get max performance!

Also of interest: this notebook, which JIT compiles the __call__ method for BLOOM: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb. You can see the speed-up you get by JIT'ing the fprop! We can do something similar for your benchmark, comparing the iteration time of PyTorch to Flax (rather than the accuracy).

sanchit-gandhi avatar Sep 20 '22 12:09 sanchit-gandhi

@sanchit-gandhi So the code I used was something like this:

'''
  Trying inference with JAX. Note: it currently errors out without modifying the source code.
  I was only concerned with speed, so I just silenced the errors by setting, in modeling_flax_wav2vec2.py:
  self.config.do_stable_layer_norm = True
  self.config.feat_extract_norm = "layer"
'''
import torchaudio
from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", from_pt=True)

sig, sr = torchaudio.load('out.mp3')  # Make sure you have linux and ffmpeg 4 or use wav/mp3 format + soundfile/librosa
# Preprocessing is done in a prefetch step in my pipeline, so its cost is not measured here
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values

%%timeit  # jupyter magic or you could use time
logits = model(input_values).logits
'''
   Just standard PyTorch inference, nothing fancy
'''
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

sig, sr = torchaudio.load('out.mp3')  # Make sure you have linux and ffmpeg 4 or use wav format + soundfile/librosa
# Preprocessing is done in a prefetch step in my pipeline, so its cost is not measured here
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values

%%timeit  # jupyter magic or you could use time
logits = model(input_values).logits

Now, what's interesting is that with Flax inference the GPU utilisation jumps around between 0% and 20%. There might be some problem with memory allocation on CUDA, I don't know...

Tried this:

@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits

It seems like jit expects a known array type for Flax, so I also added input_values = numpy.array(input_values). In this case the GPU was not used. On CPU, a speed-up is definitely present.

I installed CUDA, cuDNN, Flax and JAX the following way:

conda install -c conda-forge cudatoolkit-dev=11.2 cudnn=8.2.0
pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

P.S. What do you guys think about a custom forward pass written in CUDA vs. Flax, assuming Flax performs at its peak?

ZurabDz avatar Sep 20 '22 19:09 ZurabDz

Oops, I accidentally closed the issue, sorry guys.

ZurabDz avatar Sep 20 '22 19:09 ZurabDz

> in this case GPU was not used

You can verify that you're running on an accelerator device by checking the number of JAX devices:

print(jax.device_count())

This tells you how many devices the default backend exposes; jax.devices() additionally shows whether they are CPU or GPU devices.
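As a slightly more explicit check, the sketch below prints the name of the default backend and the device list. It assumes only that JAX is installed. A CPU-only install also reports one device, so the count alone does not settle the question.

```python
import jax

# Name of the backend JAX dispatches to by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# Devices belonging to that default backend.
devices = jax.devices()

print(f"default backend: {backend}")
print(f"devices: {devices}")
```

If the backend is reported as "cpu" even though a GPU is present, the installed jaxlib build is probably a CPU-only one.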

> On CPU speed up definitely is present.

Did you make sure to use .block_until_ready() on the output logits? https://jax.readthedocs.io/en/latest/async_dispatch.html

Perhaps you could post your full code snippet for the JIT benchmark!

I'd do something as follows:

@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits

input_values = jnp.array(input_values)
# Compilation time (should be ~s)
%time logits = flax_model_jitted(input_values=input_values).block_until_ready()
# Compiled time (should be ~ms)
%time logits = flax_model_jitted(input_values=input_values).block_until_ready()

You can refer to the ipynb for a template on how to set up a performance test: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb
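For readers without a notebook at hand, the same compile-then-run measurement can be sketched in plain Python. `toy_forward` and `timed_call` are hypothetical stand-ins introduced for illustration (the toy function replaces `model(input_values).logits` so that the snippet runs without downloading a checkpoint); only `jax.jit` and `.block_until_ready()` come from the discussion above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_forward(x):
    # Hypothetical stand-in for model(input_values).logits.
    return jnp.tanh(x @ x.T).sum(axis=-1)


def timed_call(fn, *args):
    """Run fn, wait for the result, and return (output, elapsed seconds)."""
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # wait for async dispatch to finish
    return out, time.perf_counter() - start


x = jnp.ones((512, 512))
_, first = timed_call(toy_forward, x)     # includes tracing + XLA compilation
out, second = timed_call(toy_forward, x)  # reuses the compiled executable
print(f"first call: {first * 1e3:.1f} ms, second call: {second * 1e3:.1f} ms")
```

Timing the first call separately keeps the one-off compilation cost out of the steady-state number, which is the figure to compare against PyTorch.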

sanchit-gandhi avatar Sep 21 '22 09:09 sanchit-gandhi

import jax

# This prints 1
print(jax.device_count(backend='gpu'))

Unfortunately, GPU utilisation is still 0%, which suggests inference is still being done on CPU. Memory is definitely allocated when the model is loaded, but after that nothing really happens on the GPU. Currently the Flax benchmark looks like this:

from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
import torch
import torchaudio
import jax
from jax import numpy

print(jax.device_count(backend='gpu')) # this prints 1

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", from_pt=True)

sig, sr = torchaudio.load('out.mp3')
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values 
input_values = numpy.array(input_values)

@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits
    
%%timeit
logits = flax_model_jitted(input_values=input_values).block_until_ready()
# 90.8 ms ± 41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
logits = flax_model_jitted(input_values=input_values).block_until_ready() 
# 62.8 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Is something wrong with jax.numpy.array? Do I need to somehow force things onto the GPU? print(jax.device_count(backend='gpu')) prints 1; is this what I should be expecting for GPU usage?

ZurabDz avatar Sep 21 '22 11:09 ZurabDz

@sanchit-gandhi sorry for pinging, but any thoughts on what could be the reason for such weird results?

ZurabDz avatar Sep 23 '22 20:09 ZurabDz

Hey @ZurabDz! Sorry for the late reply. It looks like JAX is recognising your GPU, which is good! The problem likely lies in your preparation of the inputs. First, what I'd try is returning the input values as NumPy arrays:

import jax.numpy as jnp

sig, sr = torchaudio.load("out.mp3")
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="np").input_values 
input_values_jnp = jnp.array(input_values)

and then pass these to the model.

If that does not help, then you can try using jax.device_put() to place the inputs on the GPU explicitly, as explained in multiplying-matrices.
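A minimal sketch of that explicit placement, assuming the inputs are already a NumPy array. The zero-filled `input_values` below is a hypothetical stand-in for the processor output, and the first device of the default backend is used as the target.

```python
import jax
import numpy as np

# Hypothetical stand-in for the processor output: one second of 16 kHz audio.
input_values = np.zeros((1, 16_000), dtype=np.float32)

# First device of the default backend (a GPU when a CUDA build of jaxlib is active).
device = jax.devices()[0]

# Copy the host array onto that device ahead of the timed model call.
input_values_dev = jax.device_put(input_values, device)

print(f"placed array of shape {input_values_dev.shape} on {device}")
```

Passing `input_values_dev` to the jitted function keeps the host-to-device copy out of the timed region.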

sanchit-gandhi avatar Oct 10 '22 11:10 sanchit-gandhi

Sorry, I am currently unable to test device_put because I am occupied with a different problem. Maybe we should close the issue and reopen it later if the problem persists.

ZurabDz avatar Oct 18 '22 10:10 ZurabDz

Hey @ZurabDz! Sure, let's close it for now and re-open if you continue to encounter this problem. Feel free to open a new issue for the different problem you are facing and tag me!

sanchit-gandhi avatar Oct 18 '22 11:10 sanchit-gandhi