Flax Whisper uses a lot of GPU memory
I'm using Flax whisper-medium and it's now ~3x faster than the PyTorch deployment, but it is allocating ~10x more GPU memory. Loading the PyTorch model takes ~3GB, while loading Flax whisper-medium takes >30GB of VRAM. Is this huge memory allocation normal? And is there any built-in way to cut it down? @andyehrenberg @ArthurZucker @sanchit-gandhi
The code for loading the Flax model:
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)

jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language"])
Hey @hannan72! Could you try disabling _do_init? This way we won't initialise a random version of the parameters. Note that this isn't compatible with from_pt=True, so you'll have to load a checkpoint where the Flax weights have already been saved:
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language"])
input_features = jnp.array(input_features, dtype=jnp.float16)
pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|en|>') # we need to explicitly pass the params now since we're in Flax's stateless design
If you need to load a model where you only have PyTorch weights, you can first convert them to Flax on CPU:
import jax
# Global flag to set a specific platform, must be used at startup. ONLY DO THIS FOR SAVING WEIGHTS ON CPU!
jax.config.update('jax_platform_name', 'cpu')
model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
model.save_pretrained("save/path/to/ckpt/here")
Kill this window, and then open up a new one and load:
model, params = FlaxWhisperForConditionalGeneration.from_pretrained("save/path/to/ckpt/here", dtype=jnp.float16, _do_init=False)
Thanks for your response @sanchit-gandhi.
I tested your proposed approach: saving the Flax model with the conversion done on CPU, then restarting the kernel and loading FlaxWhisperForConditionalGeneration with _do_init disabled.
But inference time increased a lot, while GPU memory utilization didn't decrease significantly.
Results when using from_pt=True for whisper-medium on a 10-second audio clip on an A100-40GB GPU:
- GPU memory usage: ~33.1GB
- Inference time: ~0.22 seconds
Results when using _do_init=False for the Flax-saved whisper-medium on a 10-second audio clip on an A100-40GB GPU:
- GPU memory usage: ~31.1GB
- Inference time: ~16.5 seconds
Now Inference time is 80x larger!
Some of the extra GPU memory can probably be attributed to how the Flax generation implements the kv cache. Check what happens when you set max new tokens to be smaller.
Also, it doesn't make sense to run the Flax stuff within a torch.no_grad() context.
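As a rough way to test that, you could sweep max_length and compare peak memory between runs. A minimal sketch, reusing jit_generate, params and input_features from the snippets above; the nvidia-smi suggestion is just one way to probe memory.

# Flax generation pre-allocates the decoder key/value cache up to max_length,
# so the cache's contribution to peak memory should shrink as max_length goes down.
# Note: each new max_length value is a new static argument and triggers a recompile.
for max_len in (128, 64, 32):
    pred_ids = jit_generate(input_features, params=params, max_length=max_len, language='<|en|>')
    pred_ids.sequences.block_until_ready()
    # inspect GPU memory here, e.g. via nvidia-smi in another terminal
    # (with XLA preallocation disabled, so the reading reflects real allocations)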
I also found that the whisper-small checkpoint also takes ~33GB of GPU RAM!
For my fine-tuned whisper-medium, if I don't run inside torch.no_grad(), I get an error, and it is fixed simply by adding torch.no_grad():
RuntimeError Traceback (most recent call last)
/s2t-test/client_notebook/Untitled1.ipynb Cell 25 in <cell line: 3>()
1 jax.config.update('jax_platform_name', 'cpu')
----> 2 model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id , dtype=jnp.float16, from_pt=True)
3 model.save_pretrained(model_id+ "/flax/")
File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_utils.py:810, in FlaxPreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
807 model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
809 if from_pt:
--> 810 state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
811 else:
812 if is_sharded:
File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:62, in load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys)
59 pt_state_dict = torch.load(pt_path, map_location="cpu")
60 logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
---> 62 flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
63 else:
64 # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
65 flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:128, in convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
126 def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
...
--> 128 pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
130 model_prefix = flax_model.base_model_prefix
132 # use params dict if the model contains batch norm layers
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
(However, pretrained models do not need to be loaded inside torch.no_grad().)
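That error comes from tensors in the fine-tuned checkpoint still carrying requires_grad=True, so the conversion's .numpy() call fails; wrapping the load in torch.no_grad() avoids it. An alternative, as the error message itself suggests, is to detach the state dict once before converting. A minimal sketch, where the directory and "pytorch_model.bin" filename are assumptions about the checkpoint layout:

import torch

# strip gradient tracking from every tensor in the fine-tuned checkpoint,
# so the later PyTorch-to-Flax conversion no longer needs torch.no_grad()
state_dict = torch.load("whisper_model_chkp/pytorch_model.bin", map_location="cpu")
state_dict = {k: v.detach().clone() for k, v in state_dict.items()}
torch.save(state_dict, "whisper_model_chkp/pytorch_model.bin")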
That said, the results I mentioned after @sanchit-gandhi's answer were tested both with and without torch.no_grad(), and it didn't make any difference.
Now Inference time is 80x larger!
There shouldn't be any difference to inference time - are you certain you're running on GPU here? Make sure you have not set:
jax.config.update('jax_platform_name', 'cpu')
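A quick way to check what JAX is actually running on is to query the backend and see where the loaded params live. A minimal sketch, reusing params from the snippets above; note that the .device() accessor varies slightly across JAX versions.

import jax

print(jax.default_backend())  # should print "gpu" when the GPU backend is active
print(jax.devices())          # should list the A100, not only CPU devices

# confirm the loaded params actually live on the GPU
first_leaf = jax.tree_util.tree_leaves(params)[0]
print(first_leaf.device())    # newer JAX versions expose .devices() instead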
Yes, I killed the window after saving the Flax model, and afterwards I don't move the weights to CPU anymore. But it is still so slow. Have you tested this approach @sanchit-gandhi?
Have you tested this approach @sanchit-gandhi?
Extensively! See my results for A100 (PyTorch) vs pmap (TPU v4-8 + JAX):
Could you perhaps share your code @hannan72? There shouldn't be any performance difference between using / not using _do_init.
It could also be that we're recompiling each time - would be great to see your code here @hannan72 to verify!
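One way to check that from the logs is to turn on compile logging: if a compile message shows up on every call rather than only the first, something about the arguments (shapes or static values) is changing. A minimal sketch, reusing jit_generate, params and input_features from the earlier snippets:

import jax

# log every XLA compilation triggered by jit so repeated retracing is visible
jax.config.update("jax_log_compiles", True)

for _ in range(3):
    pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|en|>')
    pred_ids.sequences.block_until_ready()
# one compile message means the cached executable is reused on later calls;
# a message per iteration means jit is recompiling every time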
This is my full code:
First, the PyTorch model is loaded, converted to Flax, and then saved:
import jax
import jax.numpy as jnp
import torch
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor
pt_model_path = "/client_notebook/whisper_model_chkp"
model_id = "/client_notebook/flax_whisper_model"
jax.config.update('jax_platform_name', 'cpu')
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(pt_model_path, dtype=jnp.float16, from_pt=True)
    model.save_pretrained(model_id)
For deploying the Flax model, the following code is used:
import jax
import jax.numpy as jnp
import torch
import flax
from scipy.io import wavfile
import time
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor
model_id = "/client_notebook/flax_whisper_model"
processor = WhisperProcessor.from_pretrained(model_id)
with torch.no_grad():
    model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)
    jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])
audio_file_path = "sample_audio_5s.wav"
samplerate, data_waveform = wavfile.read(audio_file_path)
data_waveform = data_waveform / 32768.0
input_features = processor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="pt").input_features
runtime = []
for i in range(5):
    start_time = time.time()
    input_features = jnp.array(input_features, dtype=jnp.float16)
    pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|de|>', task="transcribe")
    runtime.append(time.time() - start_time)
print("Inference time:\n", runtime)
And the output is as follows:
Inference time:
[70.23309993743896, 14.300963640213013, 12.430477142333984, 13.643242120742798, 12.125237703323364]
- GPU memory utilization: 31,127 MB
- GPU type: 1x A100-40GB
- Model checkpoint: whisper_medium
- Note: GPU memory utilization when the model is imported directly from the PyTorch model (by passing from_pt=True) is 31,587 MB, just 460 MB larger. And this value (460 MB) is exactly the GPU memory utilization I get when I put the model on CPU by running jax.config.update('jax_platform_name', 'cpu') while saving the Flax model.
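(As an aside on the timing methodology: the first call includes tracing and XLA compilation, and JAX dispatch is asynchronous, so a fair measurement should block on the result. A minimal sketch, reusing the names above:)

import time

# warm-up call: triggers tracing + XLA compilation, excluded from the measurement
pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|de|>', task="transcribe")
pred_ids.sequences.block_until_ready()

start_time = time.time()
pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|de|>', task="transcribe")
pred_ids.sequences.block_until_ready()  # wait for the asynchronous computation to finish
print("Inference time:", time.time() - start_time)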
@sanchit-gandhi
Hey @hannan72 - thanks for the super detailed report and attaching your code. This is indeed a very strange phenomenon that we're seeing with such high memory utilisation for the Flax model. Based on what you've said, I think all of this is coming from when we load the model, rather than from when we do the forward pass.
I also ran a few tests on an A100, where I was comfortably able to fit a batch size of 16 on a 40GB device. If we're using 31GB of memory just for loading, there's no way that's persistent through the forward pass, otherwise a batch size of 16 wouldn't be possible.
I wonder whether we can trick JAX into using the CPU for the heavy weight loading, and then move the weights onto the GPU for the forward pass? Something along the lines of:
import jax
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor
model_id = "/client_notebook/flax_whisper_model"
processor = WhisperProcessor.from_pretrained(model_id)
# load weights on CPU
jax.config.update('jax_platform_name', 'cpu')
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)
# now move weights to GPU
jax.config.update('jax_platform_name', 'gpu')
params = jax.device_put(params, jax.devices("gpu")[0])
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])
...
This could be a workaround, but not a fix for the high memory usage we're seeing during initialisation.
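One more thing worth ruling out when reading these memory numbers: by default, the JAX GPU backend preallocates a large fraction of device memory up front, so tools like nvidia-smi report that reservation rather than what the model actually uses. Disabling preallocation is a standard XLA client setting (not specific to transformers) and gives a truer picture; a minimal sketch:

import os

# must be set before JAX initialises the GPU backend, i.e. before the first JAX import/use
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# alternatively, cap the preallocated fraction instead of disabling it entirely:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration

model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    "/client_notebook/flax_whisper_model", dtype=jnp.float16, _do_init=False
)
# nvidia-smi should now report memory that grows with actual allocations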
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Has anyone solved this? I'm new to JAX/Flax, so I have no idea why it's taking so much memory. Though I'm quite happy with the speed!