Flax Whisper uses a lot of GPU memory

Open hannan72 opened this issue 1 year ago • 8 comments

I'm using Flax whisper-medium, and it is now ~3x faster than the PyTorch deployment, but it allocates ~10x more GPU memory: loading the PyTorch model takes ~3GB, while loading Flax whisper-medium takes >30GB of VRAM. Is this huge memory allocation normal? And is there an established way to cut it down? @andyehrenberg @ArthurZucker @sanchit-gandhi

The code for loading Flax model:

with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
    jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language"])

hannan72 avatar Mar 17 '23 09:03 hannan72

Hey @hannan72! Could you try disabling _do_init? This way we won't initialise a random version of the parameters. Note that this isn't compatible with from_pt=True, so you'll have to load a checkpoint where the Flax weights have already been saved:

model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)

jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language"])

input_features = jnp.array(input_features, dtype=jnp.float16)
pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|en|>')  # we need to explicitly pass the params now since we're in Flax's stateless design

If you need to load a model where you only have PyTorch weights, you can first convert them to Flax on CPU:

import jax

# Global flag to set a specific platform, must be used at startup. ONLY DO THIS FOR SAVING WEIGHTS ON CPU!
jax.config.update('jax_platform_name', 'cpu')

model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
model.save_pretrained("save/path/to/ckpt/here")

Kill this window, and then open up a new one and load:

model, params = FlaxWhisperForConditionalGeneration.from_pretrained("save/path/to/ckpt/here", dtype=jnp.float16, _do_init=False)

sanchit-gandhi avatar Mar 17 '23 16:03 sanchit-gandhi

Thanks for your response @sanchit-gandhi. I've tested your proposed approach: I saved the Flax model after converting it on CPU, then restarted the kernel and loaded FlaxWhisperForConditionalGeneration with _do_init=False. But inference time increased a lot, while GPU memory utilization didn't decrease significantly.

Results when using from_pt=True for whisper-medium on a 10-second audio clip on an A100-40GB GPU:

  • GPU memory usage: ~33.1GB
  • Inference time: ~0.22 seconds

Results when using _do_init=False for the saved Flax whisper-medium on a 10-second audio clip on an A100-40GB GPU:

  • GPU memory usage: ~31.1GB
  • Inference time: ~16.5 seconds

Now Inference time is 80x larger!

hannan72 avatar Mar 19 '23 20:03 hannan72

Some of the extra GPU memory can probably be attributed to how the flax generation implements the kv cache. Check what happens when you set max new tokens to be smaller.

andyehrenberg avatar Mar 19 '23 21:03 andyehrenberg
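
To get a feel for how much memory the key/value cache mentioned above might need, here is a rough back-of-envelope sketch. It assumes the whisper-medium configuration (24 decoder layers, hidden size 1024), float16 storage, and that only decoder self-attention keys and values are cached up to max_length. The function name is made up for illustration and none of this is taken from the library's code.

```python
def self_attn_kv_cache_bytes(num_layers, batch_size, max_length, d_model, bytes_per_elem=2):
    """Bytes for cached keys and values across all decoder self-attention layers."""
    # Factor 2: one tensor for keys and one for values, each [batch, max_length, d_model].
    return 2 * num_layers * batch_size * max_length * d_model * bytes_per_elem


# Assumed whisper-medium values: 24 decoder layers, d_model = 1024; float16 = 2 bytes.
for max_length in (32, 128, 448):
    size = self_attn_kv_cache_bytes(24, 1, max_length, 1024)
    print(f"max_length={max_length}: {size / 2**20:.1f} MiB")
```

Under these assumptions the cache at max_length=128 is about 12 MiB per sequence, so shrinking max_length is cheap to try but may not, on its own, explain a footprint of tens of gigabytes. Cross-attention states and XLA workspace buffers are not counted here.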

Also, it doesn't make sense to run the flax stuff within a torch.no_grad() context.

andyehrenberg avatar Mar 19 '23 21:03 andyehrenberg

I also found that the whisper_small checkpoint takes ~33GB of GPU RAM as well!

hannan72 avatar Mar 19 '23 22:03 hannan72

For my fine-tuned whisper-medium, if I don't run inside torch.no_grad(), I get an error, and it is fixed just by adding torch.no_grad():

RuntimeError                              Traceback (most recent call last)
/s2t-test/client_notebook/Untitled1.ipynb Cell 25 in <cell line: 3>()
      1 jax.config.update('jax_platform_name', 'cpu')
----> 2 model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id , dtype=jnp.float16, from_pt=True)
      3 model.save_pretrained(model_id+ "/flax/")

File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_utils.py:810, in FlaxPreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    807 model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
    809 if from_pt:
--> 810     state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
    811 else:
    812     if is_sharded:

File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:62, in load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys)
     59     pt_state_dict = torch.load(pt_path, map_location="cpu")
     60     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
---> 62     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
     63 else:
     64     # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
     65     flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)

File /opt/conda/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:128, in convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
    126 def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
...
--> 128     pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
    130     model_prefix = flax_model.base_model_prefix
    132     # use params dict if the model contains batch norm layers

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

(However, pretrained models do not need to be loaded inside torch.no_grad().)

That said, the results I mentioned after @sanchit-gandhi's answer were tested both with and without torch.no_grad(), and it didn't make any difference.

hannan72 avatar Mar 19 '23 22:03 hannan72

Now Inference time is 80x larger!

There shouldn't be any difference to inference time - are you certain you're running on GPU here? Make sure you have not set:

jax.config.update('jax_platform_name', 'cpu')

sanchit-gandhi avatar Mar 22 '23 10:03 sanchit-gandhi
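
A quick way to check which platform JAX is actually using is sketched below. It relies only on jax.default_backend() and jax.devices(); the variable name on_accelerator is just for illustration.

```python
import jax

# Platform JAX uses by default for new arrays and jitted functions ("cpu", "gpu" or "tpu").
print("default backend:", jax.default_backend())

# Devices visible on the default backend; on a GPU machine these should be GPU devices.
print("devices:", jax.devices())

# Illustrative flag: False means computations will silently run on the CPU.
on_accelerator = jax.default_backend() != "cpu"
print("running on accelerator:", on_accelerator)
```

If this prints "cpu" in the inference session, a slowdown of the size reported above would not be surprising.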

Now Inference time is 80x larger!

There shouldn't be any difference to inference time - are you certain you're running on GPU here? Make sure you have not set:

jax.config.update('jax_platform_name', 'cpu')

Yes, I kill the window after saving the flax model and afterwards I don't move weights to CPU anymore. But it is so slow. Have you tested this approach @sanchit-gandhi ?

hannan72 avatar Mar 25 '23 10:03 hannan72

Have you tested this approach @sanchit-gandhi ?

Extensively! See my results for A100 (PyTorch) vs pmap (TPU v4-8 + JAX):

[Image: Screenshot 2023-04-03 at 11 54 17]

Could you perhaps share your code @hannan72? There shouldn't be any performance difference between using / not using _do_init.

sanchit-gandhi avatar Apr 03 '23 10:04 sanchit-gandhi

It could also be that we're recompiling each time - would be great to see your code here @hannan72 to verify!

sanchit-gandhi avatar Apr 04 '23 16:04 sanchit-gandhi

It could also be that we're recompiling each time - would be great to see your code here @hannan72 to verify!

This is my full code:

Firstly, the PyTorch model is loaded, converted to Flax, and then saved:

import jax
import jax.numpy as jnp
import torch
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor

pt_model_path = "/client_notebook/whisper_model_chkp"
model_id = "/client_notebook/flax_whisper_model"

jax.config.update('jax_platform_name', 'cpu')
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(pt_model_path, dtype=jnp.float16, from_pt=True)
    model.save_pretrained(model_id)

For deploying the Flax model, the following code is used:

import jax
import jax.numpy as jnp
import torch
import flax
from scipy.io import wavfile
import time
from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor

model_id = "/client_notebook/flax_whisper_model"
processor = WhisperProcessor.from_pretrained(model_id)

with torch.no_grad():
    model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)
    jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])

audio_file_path = "sample_audio_5s.wav"
samplerate, data_waveform = wavfile.read(audio_file_path)
data_waveform = data_waveform / 32768.0  # scale 16-bit PCM samples to [-1.0, 1.0)
input_features = processor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="pt").input_features

runtime=[]
for i in range(5):
    start_time = time.time()
    input_features = jnp.array(input_features, dtype=jnp.float16)
    pred_ids = jit_generate(input_features, params=params, max_length=128, language='<|de|>', task ="transcribe")
    runtime.append(time.time() - start_time)
print("Inference time:\n", runtime)

And the output is as follows:

Inference time: 
[70.23309993743896, 14.300963640213013, 12.430477142333984, 13.643242120742798, 12.125237703323364]

  • GPU memory utilization: 31,127 MB
  • GPU type: 1x A100-40GB
  • Model checkpoint: whisper_medium
  • Note: GPU memory utilization when the model is imported directly from the PyTorch model (by passing from_pt=True) is 31,587 MB, which is just 460 MB larger. This 460 MB is exactly the GPU memory utilization I see when I put the model on CPU by running jax.config.update('jax_platform_name', 'cpu') while saving the Flax model.

@sanchit-gandhi

hannan72 avatar Apr 11 '23 11:04 hannan72
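
To separate compilation cost from steady-state latency in a loop like the one above, a small timing helper along these lines may help. It is a sketch: the name time_calls and the stand-in workload are hypothetical. With JAX, fn would wrap the jit_generate call and sync would be something that blocks until the result is ready (for example jax.block_until_ready, if your JAX version provides it), since JAX dispatches work asynchronously.

```python
import time


def time_calls(fn, n_runs=5, sync=lambda out: out):
    """Return (first_call_seconds, list_of_later_call_seconds)."""
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        sync(fn())  # sync should block until the result is actually computed
        timings.append(time.perf_counter() - start)
    return timings[0], timings[1:]


# Stand-in workload so the sketch runs anywhere; replace with the real call.
first, rest = time_calls(lambda: sum(range(100_000)), n_runs=5)
print(f"first call: {first:.4f}s, later calls: {[round(t, 4) for t in rest]}")
```

If the later calls stay close to the first one, recompilation on every call is a likely suspect. In the numbers reported above the first call (~70 s) is far slower than the rest (~12-14 s), which suggests the compiled function is being reused and the steady-state cost itself is high.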

Hey @hannan72 - thanks for the super detailed report and attaching your code. This is indeed a very strange phenomenon that we're seeing with such high memory utilisation for the Flax model. Based on what you've said, I think all of this is coming from when we load the model, rather than from when we do the forward pass.

I also ran a few tests on an A100, where I was comfortably able to fit a batch size of 16 on a 40GB device. If loading takes 31GB of memory, there's no way that allocation persists into the forward pass; otherwise a batch size of 16 wouldn't be possible.

I wonder whether we can trick JAX into using the CPU for the heavy weight loading, and then move the weights onto the GPU for the forward pass? Something along the lines of:

import jax
import jax.numpy as jnp

from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor

model_id = "/client_notebook/flax_whisper_model"
processor = WhisperProcessor.from_pretrained(model_id)

# load weights on CPU
jax.config.update('jax_platform_name', 'cpu')
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, _do_init=False)

# now move weights to GPU
jax.config.update('jax_platform_name', 'gpu')
params = jax.device_put(params, jax.devices('gpu')[0])  # device_put expects a Device object, not a platform string

jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])
...

This could be a workaround, but not a fix for the high memory usage we're seeing during initialisation.

sanchit-gandhi avatar Apr 19 '23 17:04 sanchit-gandhi
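
A variation on the "load on CPU, then move to the accelerator" idea above, sketched with a toy parameter tree so that it runs without a checkpoint. It uses the jax.default_device context manager instead of toggling the global platform flag. The params dictionary is a stand-in for the real weights, and whether from_pretrained respects this context when placing weights is an assumption that would need checking.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]   # the CPU backend is available even on GPU machines
target = jax.devices()[0]     # first device of the default backend (a GPU if one is present)

# Stand-in for the loaded `params` pytree; arrays created here are placed on the CPU.
with jax.default_device(cpu):
    params = {"encoder": {"kernel": jnp.ones((4, 4), dtype=jnp.float16)}}

# Move every leaf of the pytree to the target device in one call.
params = jax.device_put(params, target)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
```

On a CPU-only machine cpu and target are the same device, so the transfer is a no-op and the sketch still runs.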

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar May 14 '23 15:05 github-actions[bot]

Has anyone solved this? I'm new to JAX/Flax, so I have no idea why it's taking so much memory. Though I'm quite happy with the speed.

RomanKoshkin avatar May 01 '24 10:05 RomanKoshkin
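
One thing worth ruling out, although it is not confirmed anywhere in this thread: by default JAX preallocates a large fraction of total GPU memory when the first JAX operation runs (the JAX documentation has cited 75%; the exact default depends on the version). On a 40GB card that would be about 30GB regardless of model size, which would be consistent with whisper_small and whisper-medium showing nearly the same footprint above. A minimal sketch of how to turn this off, using the environment variables described in the JAX GPU memory allocation docs:

```python
import os

# These must be set before JAX initialises its GPU backend (safest: before `import jax`).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of up front

# Alternative: keep preallocation but cap the pool at a fraction of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.30"

# import jax  # import JAX only after the variables above are set
print({k: v for k, v in os.environ.items() if k.startswith("XLA_PYTHON_CLIENT")})
```

With preallocation disabled, the number reported by nvidia-smi should track what the model actually needs more closely, possibly at the cost of more memory fragmentation.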