Flax vs torch benchmark on Wav2vec2
So my question is, should FlaxWav2Vec2ForCTC generally be faster than Wav2Vec2ForCTC?
1.14 s ± 138 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> FlaxWav2Vec2ForCTC
37.7 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> Wav2Vec2ForCTC
So the question is: should Flax be faster than the default PyTorch model?
P.S.: The benchmarks were done on GPU; VRAM usage also seems drastically larger on Flax for some reason.
Maybe of interest to @sanchit-gandhi.
Hey @ZurabDz! FlaxWav2Vec2ForCTC should be faster than Wav2Vec2ForCTC if the __call__ method is just-in-time (JIT) compiled (cf. https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). Could you share your code for running this benchmark? We can then go through it and make sure the Flax model is appropriately set up to get maximum performance!
Also of interest: this notebook which JIT compiles the __call__ method for BLOOM https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb
You can see the speed-up you get by JIT-compiling the forward pass! We can do something similar for your benchmark, comparing the iteration time of PyTorch to Flax (rather than the accuracy).
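In short, the pattern looks something like this (a minimal sketch, not the notebook's exact code; it assumes model is a loaded FlaxWav2Vec2ForCTC and input_values is a jnp array):
import jax

# jax.jit traces the function on the first call and compiles it to an XLA
# program; every later call with the same input shapes reuses that program.
@jax.jit
def forward(input_values):
    return model(input_values).logits

logits = forward(input_values)  # first call: slow (includes compilation)
logits = forward(input_values)  # later calls: fast (compiled)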
@sanchit-gandhi So the code I used was something like this:
'''
Trying inference with JAX. Note: it currently errors out without modifying the source code.
I was only concerned with speed, so I just silenced the errors:
assigned self.config.do_stable_layer_norm = True in modeling_flax_wav2vec2.py
assigned self.config.feat_extract_norm = "layer"
'''
from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
import torchaudio
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", from_pt=True)
sig, sr = torchaudio.load('out.mp3')  # Make sure you have Linux and ffmpeg 4, or use wav/mp3 format + soundfile/librosa
# preprocessing; in my pipeline this is computed in prefetch, so I don't care how long it takes
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values
%%timeit # jupyter magic or you could use time
logits = model(input_values).logits
'''
Just standard inference, nothing fancy
'''
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torchaudio
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
sig, sr = torchaudio.load('out.mp3')  # Make sure you have Linux and ffmpeg 4, or use wav format + soundfile/librosa
# preprocessing; in my pipeline this is computed in prefetch, so I don't care how long it takes
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values
%%timeit # jupyter magic or you could use time
logits = model(input_values).logits
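A side note on the timing methodology (an assumption on my part, since the snippet above never moves the model to CUDA): PyTorch also launches CUDA kernels asynchronously, so an exact GPU timing needs a synchronize before the clock stops. A minimal sketch of a GPU variant:
import torch

# Hypothetical GPU version of the benchmark above: move model and inputs to
# CUDA, run the forward pass, then wait for all queued kernels to finish.
model_gpu = model.to("cuda")
input_values_gpu = input_values.to("cuda")

with torch.no_grad():
    logits = model_gpu(input_values_gpu).logits
torch.cuda.synchronize()  # ensure the timing covers actual kernel execution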
Now, what's interesting is that with Flax inference, GPU utilisation is jumpy, ranging from 0-20%; there might be some problem with memory allocation on CUDA, I don't know...
Tried this:
@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits
It seems like JIT expects a known input type for Flax, so I also added something like input_values = numpy.array(input_values).
In this case the GPU was not used. On CPU, the speed-up is definitely present.
I installed CUDA, cuDNN, Flax and JAX in the following way:
conda install -c conda-forge cudatoolkit-dev=11.2 cudnn=8.2.0
pip install -U "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
P.S.: What do you think about a custom forward pass written in CUDA vs. Flax, assuming Flax performs at its peak?
Oops, accidentally closed the issue, sorry guys.
In this case the GPU was not used.
You can verify that you're running on an accelerator device by checking the number of JAX devices:
print(jax.device_count())
This will tell you if you're on CPU or GPU!
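You can also check which backend JAX has defaulted to (a quick sketch; both calls are standard JAX API):
import jax

print(jax.default_backend())  # 'gpu' if the CUDA install was picked up, else 'cpu'
print(jax.devices())          # the devices JAX will place arrays and computations on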
On CPU, the speed-up is definitely present.
Did you make sure to use .block_until_ready() on the output logits? https://jax.readthedocs.io/en/latest/async_dispatch.html
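For context, a minimal sketch of why this matters (hypothetical array sizes): without block_until_ready, you only time the asynchronous dispatch of the work, not the computation itself.
import jax.numpy as jnp

x = jnp.ones((5000, 5000))

# Returns almost immediately: only the dispatch to the device is timed.
%time y = jnp.dot(x, x)

# Blocks until the device has finished: this is the real compute time.
%time y = jnp.dot(x, x).block_until_ready()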
Perhaps you could post your full code snippet for the JIT benchmark!
I'd do something as follows:
@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits
input_values = jnp.array(input_values)
# Compilation time (should be ~s)
%time logits = flax_model_jitted(input_values=input_values).block_until_ready()
# Compiled time (should be ~ms)
%time logits = flax_model_jitted(input_values=input_values).block_until_ready()
You can refer to the ipynb for a template on how to set up a performance test: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb
import jax
# This prints 1
print(jax.device_count(backend='gpu'))
Unfortunately, GPU utilisation is still 0%, which means inference is still being done on the CPU. Memory is definitely allocated when the model is loaded, but after that nothing really happens on it. Currently the Flax benchmark looks like this:
from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
import torch
import torchaudio
import jax
from jax import numpy
print(jax.device_count(backend='gpu')) # this prints 1
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", from_pt=True)
sig, sr = torchaudio.load('out.mp3')
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="pt").input_values
input_values = numpy.array(input_values)
@jax.jit
def flax_model_jitted(input_values):
    return model(input_values).logits
%%timeit
logits = flax_model_jitted(input_values=input_values).block_until_ready()
# 90.8 ms ± 41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
logits = flax_model_jitted(input_values=input_values).block_until_ready()
# 62.8 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Is something wrong with jax.numpy.array? Do I need to somehow force things onto the GPU?
print(jax.device_count(backend='gpu')) prints 1; is this what I should be expecting for GPU usage?
@sanchit-gandhi Sorry for pinging, but any thoughts on what could be the reason for such weird results?
Hey @ZurabDz! Sorry for the late reply. It looks like JAX is recognising your GPU, which is good! The problem likely lies in your preparation of the inputs. First, what I'd try is returning the input values as np arrays:
import jax.numpy as jnp
sig, sr = torchaudio.load("out.mp3")
input_values = processor(sig[0], sampling_rate=16_000, return_tensors="np").input_values
input_values_jnp = jnp.array(input_values)
and then pass these to the model.
If that does not help, then you can try using device_put() as explained in the multiplying-matrices section of the JAX quickstart.
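Something along these lines (a sketch, reusing flax_model_jitted from the snippet above):
import jax
import jax.numpy as jnp

# Explicitly commit the inputs to the first GPU before calling the model.
gpu_device = jax.devices("gpu")[0]
input_values_gpu = jax.device_put(jnp.array(input_values), gpu_device)

logits = flax_model_jitted(input_values=input_values_gpu).block_until_ready()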
Sorry, I'm currently unable to test device_put; I'm occupied with a different problem. Maybe we should close the issue and re-open it later if the problem persists.
Hey @ZurabDz! Sure, let's close it for now and re-open if you continue to encounter this problem. Feel free to open a new issue for the different problem you are facing and tag me!