
5x faster text generation on multi-GPU setups (+ lower VRAM consumption)

Open emvw7yf opened this issue 2 years ago • 17 comments

TL;DR: the patch below makes multi-GPU inference 5x faster.

I noticed that text generation is significantly slower on multi-GPU than on single-GPU setups. Some results are below, using LLaMA models with the full 2048-token context window. I also tested with GPT-J, and the results are similar:

| Model size | Quantization | GPU configuration | Seconds/token | Total peak VRAM usage |
|---|---|---|---|---|
| 7B | 16-bit | 1x4090 | 0.025 | 14996 MiB |
| 7B | 16-bit | 2x4090 | 0.22 (8x slower) | 16484 MiB |
| 7B w/patch | 16-bit | 2x4090 | 0.026 | 15891 MiB |
| 13B | 8-bit w/bnb | 1x4090 | 0.067 | 16076 MiB |
| 13B | 8-bit w/bnb | 2x4090 | 0.34 (5x slower) | 18087 MiB |
| 13B w/patch | 8-bit w/bnb | 2x4090 | 0.070 | 16883 MiB |
| 13B | 4-bit w/gptq | 1x4090 | 0.036 | 10219 MB |
| 13B | 4-bit w/gptq | 2x4090 | 0.228 (6x slower) | 11091 MB |
| 13B w/patch | 4-bit w/gptq | 2x4090 | 0.038 | 10336 MB |

This makes large models prohibitively slow: running a 65B model on 4x3090 takes more than 4 seconds/token, which is unusable for interactive applications.

After some profiling and debugging, I narrowed this down to a simple problem:

  1. accelerate wraps the LlamaModel and moves the HUGE past_key_values tensor to the execution_device of LlamaModel (which is GPU 0) at the beginning of the inference.
  2. It also wraps each LlamaDecoderLayer and moves the same past_key_values (which stays constant during the entire inference pass) from GPU 0 to the execution device of each layer, repeatedly moving it between GPUs for every layer that is not on GPU 0.
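A back-of-the-envelope sketch of the traffic described in steps 1 and 2 helps show why it dominates. It is a toy cost model, not accelerate's code, and it assumes LLaMA-7B-like dimensions (32 decoder layers, hidden size 4096, fp16), a batch of 1, and that for every layer placed on a GPU other than GPU 0, that layer's key/value cache makes one trip to GPU 0 and one trip back per generated token. The function names are made up for the illustration.

```python
# Toy cost model (an illustration, not accelerate's implementation): bytes of
# KV cache that cross a GPU boundary for one generated token.


def kv_cache_bytes_per_layer(seq_len, hidden_size=4096, dtype_bytes=2, batch=1):
    """Keys + values for one layer: 2 tensors of batch x seq_len x hidden_size."""
    return 2 * batch * seq_len * hidden_size * dtype_bytes


def cache_transfer_bytes(layer_devices, seq_len, skip_cache=False):
    """Bytes moved between GPUs for one decoding step.

    layer_devices[i] is the GPU index holding decoder layer i.
    skip_cache=True models the patch: the cache stays where it was produced.
    """
    if skip_cache:
        return 0
    per_layer = kv_cache_bytes_per_layer(seq_len)
    # Assumed traffic: every layer not on GPU 0 sends its cache to GPU 0
    # and gets it back, i.e. two copies per layer per generated token.
    remote_layers = sum(1 for dev in layer_devices if dev != 0)
    return remote_layers * 2 * per_layer


if __name__ == "__main__":
    devices = [0] * 16 + [1] * 16  # 32 layers split evenly over two GPUs
    mib = 1024 ** 2
    print(cache_transfer_bytes(devices, 2048) / mib, "MiB per token without the patch")
    print(cache_transfer_bytes(devices, 2048, skip_cache=True) / mib, "MiB per token with it")
```

With a 16/16 split and a full 2048-token context this comes to about 1 GiB of cache traffic per generated token, versus none when the cache is left where it was produced. The real figure depends on the model, context length and hook details, so treat it as an order of magnitude only.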

This unnecessary repeated moving consumes up to 85% of the inference time. Furthermore, because it makes a copy of the past_key_values on each GPU, it significantly increases VRAM usage (although I didn't measure the exact number).

I'm not familiar enough with the accelerate code base to fix the root cause, but here's a simple patch that solves this problem. It keeps past_key_values sharded across GPUs, so it is never moved between GPUs (saving execution time) and its VRAM usage is split across GPUs (saving VRAM).

  1. Save this as llama_accelerate_path.py:

from accelerate.hooks import ModelHook, AlignDevicesHook, add_hook_to_module
from accelerate.utils import find_device, send_to_device
from typing import Mapping


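def send_to_device_except(data, device, non_blocking=False, skip_keys=()):
    """Like send_to_device, but leaves the top-level keys listed in skip_keys where they are."""
    if isinstance(data, Mapping):
        return type(data)({
            k: v if k in skip_keys else send_to_device(v, device, non_blocking)
            for k, v in data.items()
        })
    else:
        # not a mapping: nothing to skip, move everything to `device`
        return send_to_device(data, device, non_blocking)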


class AlignLogitsHook(AlignDevicesHook):
    def pre_forward(self, module, *args, **kwargs):
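        if self.io_same_device:
            # remember where the inputs came from so outputs can be sent back there
            self.input_device = find_device([args, kwargs])

        # move all inputs to the execution device except the KV cache,
        # which stays sharded on the GPUs of the layers that produced it
        return (
            send_to_device(args, self.execution_device),
            send_to_device_except(kwargs, self.execution_device, skip_keys=("past_key_values",)),
        )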

    def post_forward(self, module, output):
        if self.io_same_device and self.input_device is not None:
            output = send_to_device_except(output, self.input_device, skip_keys=("past_key_values",))
        return output


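def apply_to_model(model):
    # replace the top-level hook installed by device_map="auto", keeping its execution device
    hook = AlignLogitsHook(execution_device=model._hf_hook.execution_device, io_same_device=True)
    # add_hook_to_module returns the module, so `model = apply_to_model(model)` works
    return add_hook_to_module(model, hook)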
  2. Add this to the model loading code:
  # model loading with device_map="auto"
  # model = AutoModelForCausalLM.from_pretrained(...., device_map="auto").eval()

  # apply the patch (this works for llama models only):
  import llama_accelerate_path
  model = llama_accelerate_path.apply_to_model(model)

NOTE: there is another redundancy: attention_mask and some other tensors are also copied from GPU 0 to GPU n for every layer's execution. It can be solved in a similar way, but it's less of a problem because attention_mask is much smaller than past_key_values.
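A quick size comparison illustrates why the mask matters less. This sketch assumes fp16 tensors, a batch of 1, an expanded attention mask of shape (batch, 1, q_len, kv_len) with q_len = 1 during incremental decoding, and a LLaMA-7B-like hidden size of 4096; the helper names are hypothetical.

```python
# Rough per-layer sizes during one incremental decoding step.
# Assumptions: fp16, batch of 1, a 4D mask of shape (batch, 1, q_len, kv_len)
# with q_len == 1 while generating, and hidden size 4096.


def mask_bytes(kv_len, q_len=1, batch=1, dtype_bytes=2):
    return batch * 1 * q_len * kv_len * dtype_bytes


def layer_cache_bytes(kv_len, hidden_size=4096, batch=1, dtype_bytes=2):
    # keys and values, each batch x kv_len x hidden_size
    return 2 * batch * kv_len * hidden_size * dtype_bytes


kv_len = 2048
print(f"attention_mask:    {mask_bytes(kv_len) / 1024:.0f} KiB")
print(f"one layer's cache: {layer_cache_bytes(kv_len) / 1024 ** 2:.0f} MiB")
print(f"ratio:             {layer_cache_bytes(kv_len) // mask_bytes(kv_len)}x")
```

Under these assumptions each per-layer copy of the mask is about 4 KiB against 32 MiB for that layer's cache, so the mask copies are cheap by comparison.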

It would be great if a more universal version of this fix could be merged to the accelerate mainline.

emvw7yf, May 06 '23 03:05

I implemented this patch and can confirm that it speeds up generation by a factor of 5. Memory usage seems to increase, though.

Before applying the patch I could run 1700 context length, now I OOM at 1500.

Great work, though!

Dhaladom, May 06 '23 11:05

Thanks a lot for the analysis and your fix suggestion. The idea of having a skip_keys argument in dispatch_model could definitely live in Accelerate and Transformers could then set it to the past_key_values when appropriate.

sgugger, May 06 '23 15:05

I ran some more benchmarks and confirmed that the patch does reduce VRAM usage a little (see the updated table above), but it also changes how VRAM usage is distributed across GPUs.

@Dhaladom: I was able to run 65B int4 on 2x4090 with up to 2019 tokens (oh, that is truly killing the perfectionist in me...). I used the following device_map, and I also had to kill the entire GUI on the machine (log out and run sudo service gdm3 stop on Ubuntu, or just use integrated graphics or a cheap third GPU for the GUI).

device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 0, 'model.layers.27': 0, 'model.layers.28': 0, 'model.layers.29': 0, 'model.layers.30': 0, 'model.layers.31': 0, 'model.layers.32': 0, 'model.layers.33': 0, 'model.layers.34': 0, 'model.layers.35': 0, 'model.layers.36': 0, 'model.layers.37': 0, 'model.layers.38': 0, 'model.layers.39': 0, 'model.layers.40': 1, 'model.layers.41': 1, 'model.layers.42': 1, 'model.layers.43': 1, 'model.layers.44': 1, 'model.layers.45': 1, 'model.layers.46': 1, 'model.layers.47': 1, 'model.layers.48': 1, 'model.layers.49': 1, 'model.layers.50': 1, 'model.layers.51': 1, 'model.layers.52': 1, 'model.layers.53': 1, 'model.layers.54': 1, 'model.layers.55': 1, 'model.layers.56': 1, 'model.layers.57': 1, 'model.layers.58': 1, 'model.layers.59': 1, 'model.layers.60': 1, 'model.layers.61': 1, 'model.layers.62': 1, 'model.layers.63': 1, 'model.layers.64': 1, 'model.layers.65': 1, 'model.layers.66': 1, 'model.layers.67': 1, 'model.layers.68': 1, 'model.layers.69': 1, 'model.layers.70': 1, 'model.layers.71': 1, 'model.layers.72': 1, 'model.layers.73': 1, 'model.layers.74': 1, 'model.layers.75': 1, 'model.layers.76': 1, 'model.layers.77': 1, 'model.layers.78': 1, 'model.layers.79': 1, 'model.norm': 1, 'lm_head': 1}
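Typing out every layer is error-prone; a map with the same placement can also be generated. A minimal sketch: `two_gpu_device_map` is a hypothetical helper, and its defaults (80 decoder layers, split at layer 40) simply mirror the map above.

```python
def two_gpu_device_map(num_layers=80, split_at=40):
    """Build a LLaMA-style device_map: embeddings and layers [0, split_at) on
    GPU 0; the remaining layers, the final norm and lm_head on GPU 1.

    Hypothetical helper; module names follow the hand-written map above.
    """
    device_map = {"model.embed_tokens": 0}
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = 0 if i < split_at else 1
    device_map["model.norm"] = 1
    device_map["lm_head"] = 1
    return device_map


device_map = two_gpu_device_map()
# model = AutoModelForCausalLM.from_pretrained(..., device_map=device_map)
```

Moving `split_at` up or down is then a one-character change when one card runs out of memory before the other.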

emvw7yf, May 06 '23 20:05

@emvw7yf: If there were no perfectionism, I would already be generating useful output with those models instead of optimizing further. 😄

The device map is surprising, as on my system the first card ("cuda:0") shows no VRAM activity at all after its initial loading. Accordingly, I mapped as many tensors as possible onto that card, and then the second one goes OOM.

The discrepancies may be related to Triton. I don't yet understand what's really going on there, but I run auto-tune before inference, and the auto-tune algorithm only benchmarks the first GPU. It may be related to that.

Other than that, I'm already using a headless Linux server.

Anyway, thanks for taking the trouble and testing it out.

Dhaladom, May 06 '23 21:05

I've noticed this performance drop too on my 2x3090 setup. Thank you for the fix! Will this patch be merged into the main branch eventually?

Another question, maybe a stupid one: even with the patch, it seems you get the same token-generation speed with 2 GPUs as with 1 GPU (table above). Why is it not twice as fast? The two GPUs should be computing in parallel, right?

And another one: why doesn't quantization make inference faster? Intuitively, fewer bits per parameter should mean faster computation.

Thanks for your help!

g588928812, May 18 '23 18:05

> And another one: why does quantization not make inference faster? intuitively, less bits per parameter, faster calculation.

I can't speak from experience with this PR specifically yet (I just learned about it and plan to try it soon), and I'm not yet very experienced with dual-GPU inference. But in general, based on my testing of GPTQ 4-bit models, a quantised model definitely does run much faster than a float16 model... unless you are CPU-bottlenecked.

Recently I ran a comprehensive set of GPTQ performance benchmarks using AutoGPTQ, which I posted in the AutoGPTQ repo. I was testing with a 4090 on a Runpod server with a 24-core AMD Epyc CPU: a high core count, but low single-core performance. Not that I really thought about the CPU at first.

I got a peak of 28 tokens/s with 4-bit GPTQ, and that seemed reasonable given what I'd come to expect from prior testing and from other people in the community. With HF float16 I got 23 tokens/s. Not as big a performance difference as I expected, but 4-bit was still faster. I did notice that I never saw more than 30-45% GPU usage.

Then someone asked me why my results were so slow.

Long story short: it turned out that all my benchmarks were CPU bottlenecked. When you do inference with transformers/pytorch, only a single CPU core is used. And it seems that with the faster GPUs, the CPU can absolutely become a bottleneck. Hence the low GPU usage %.

I searched around for a better CPU, and found a Runpod container with an i9-13900K, which is the current leader on single-core benchmarks.

With that, still using a 4090, I got 50 tokens/s on float16 and 98 tokens/s on 4bit GPTQ! You can see those results (and my amazement), here: https://github.com/PanQiWei/AutoGPTQ/issues/49#issuecomment-1539094460

So TL;DR: if you're getting the same performance from fp16 and 4-bit, check whether one CPU core is pegged at 100% while GPU usage stays well below 100%.

And if anyone knows any tricks or methods to avoid this single-core bottleneck, I would love to hear them. I've since experimented with the transformers pipeline using a batch_size greater than 1, and this does make full use of the GPU, even with a weak CPU. However, that doesn't help in single-prompt scenarios, and it also brings some complexities (e.g. when the prompts in a batch have varying lengths).
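One common way to deal with varying prompt lengths in a batch is left padding plus an attention mask, so every prompt ends at the same position and newly generated tokens line up on the right. The sketch below shows the idea on plain token-id lists; `left_pad` and the pad id 0 are illustrative assumptions. In transformers, setting the tokenizer's `padding_side` to `"left"` and tokenizing with `padding=True` produces the same layout.

```python
def left_pad(prompts, pad_id):
    """Left-pad token-id lists to a common length and build the attention mask.

    Hypothetical helper: mask value 1 marks a real token, 0 marks padding.
    """
    width = max(len(p) for p in prompts)
    input_ids, attention_mask = [], []
    for p in prompts:
        pad = width - len(p)
        input_ids.append([pad_id] * pad + list(p))
        attention_mask.append([0] * pad + [1] * len(p))
    return input_ids, attention_mask


ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Padding on the left rather than the right matters for decoder-only generation, because the model appends new tokens after the last position of every row.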

TheBloke, May 18 '23 22:05

> i've noticed this performance drop too on my 2x 3090 setup. thank you for the fix! will this patch be implemented in the main branch eventually?

> another question, maybe stupid: even with the patch, it seems, you get the same speed in terms of token generation for 2xGPU as for 1xGPU (table above). Why is this not twice as fast? The two GPUs should be calculating in parallel, right?

The "interactive chat" kind of workload is almost entirely memory-bound (less so on a 3090 than on a 4090, but still). The compute hardware mostly just waits for weights to load into the SM caches. So more GPUs actually make performance worse: you have more FLOPs, sure, but that's just more idle hardware, and now you've added the time to transfer state between the two GPUs over an external bus.

(On the other hand if you can batch inference work or run with longer sequences you get to reuse the weights for multiple inputs.)

> And another one: why does quantization not make inference faster? intuitively, less bits per parameter, faster calculation.

It probably can once it's optimised all the way. To some extent, quantization directly counters the problem above. You have to spend some compute unpacking the values, but compute is something you have plenty of in the single-token scenario. I don't have numbers directly comparable to the OP's measurements, but with this change I got over 112 tokens/s on 4-bit 7B with 3 batches on 1x4090.
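The memory-bound argument can be turned into a rough ceiling: if generating one token requires streaming every weight from VRAM once, tokens/s is at most bandwidth divided by the size of the weights. The sketch assumes the RTX 4090's spec-sheet bandwidth of about 1008 GB/s, 7B parameters, and layers split evenly over GPUs that run one after another. It ignores compute, KV-cache reads, inter-GPU transfers, quantization metadata (scales/zeros) and Python overhead, so real numbers will be lower. `max_tokens_per_s` is a hypothetical helper.

```python
def max_tokens_per_s(n_params, bits_per_weight, bandwidth_gb_s=1008.0,
                     n_gpus=1, batch=1):
    """Upper bound on decoding speed for a memory-bound model.

    Assumes layers are split evenly over n_gpus and that, for one decoding
    step, the GPUs run one after another (naive model parallelism).
    """
    weight_bytes = n_params * bits_per_weight / 8
    # Each GPU streams its share of the weights once per step...
    stage_s = (weight_bytes / n_gpus) / (bandwidth_gb_s * 1e9)
    # ...but the stages are sequential, so their times add up.
    step_s = stage_s * n_gpus
    # One pass over the weights yields one new token per sequence in the batch.
    return batch / step_s


for label, kwargs in [
    ("fp16, 1 GPU", dict(bits_per_weight=16)),
    ("fp16, 2 GPUs", dict(bits_per_weight=16, n_gpus=2)),
    ("4-bit, 1 GPU", dict(bits_per_weight=4)),
    ("4-bit, 1 GPU, batch 3", dict(bits_per_weight=4, batch=3)),
]:
    print(f"{label:>22}: <= {max_tokens_per_s(7e9, **kwargs):.0f} tokens/s")
```

Under these assumptions the 2-GPU split leaves the single-stream ceiling unchanged (about 72 tokens/s for fp16), while 4-bit weights raise it about fourfold and batching multiplies it by the batch size, which is the shape of the two answers above.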

aljungberg, May 18 '23 22:05

What is the status of this PR? Is it ready to merge? Are there any concerns that need to be resolved?

yhyu13, May 23 '23 12:05

> Long story short: it turned out that all my benchmarks were CPU bottlenecked. When you do inference with transformers/pytorch, only a single CPU core is used. And it seems that with the faster GPUs, the CPU can absolutely become a bottleneck. Hence the low GPU usage %.

It's impressive what you did. Just as impressive is that, in this day and age, torch only uses a single core for inference. Have you, by any chance, found a way to fix this? I don't want to buy a new CPU (and motherboard) 😆

g588928812, May 25 '23 15:05

@emvw7yf I started drafting something in the PR linked above. I am not seeing your speed-ups on two GPUs, but I have NVLink, which might be why the slowdown is not as drastic for me. Could you try the PR's branch and see if adding

model._hf_hook.skip_keys = "past_key_values"

gives the same speed-ups as your patches? It should do the same thing.

When it's confirmed this works, we can get more support directly baked in Transformers.

sgugger, May 30 '23 20:05

> gives the same speed-ups as your patches? It should do the same thing.

Works! ~5x faster (14 t/s vs 2.8 t/s in my own benchmark on 2x RTX 3090 with an OA 30B model)

g588928812, May 31 '23 03:05

Ok, support should be coming natively in Transformers with the PR above then.

sgugger, May 31 '23 18:05

With latest Accelerate as well? You can check if model._hf_hook.skip_keys errors or shows "past_key_values".

sgugger, May 31 '23 20:05

Sorry, I deleted my initial post because I saw that you had merged a PR in Accelerate too, so I updated Accelerate, but the result is the same: still slow.

model._hf_hook.skip_keys is None

Setting it manually with the line below fixes it (5x speed):

model._hf_hook.skip_keys = "past_key_values"

g588928812, May 31 '23 20:05
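For setups where `model._hf_hook.skip_keys` comes back as None, as reported here, the manual workaround can be wrapped in a small guard that leaves any existing value alone. This is a sketch: `ensure_skip_past_key_values` is a hypothetical helper, `_hf_hook` is a private attribute that may change between accelerate versions, and the `SimpleNamespace` objects only stand in for a dispatched model.

```python
from types import SimpleNamespace


def ensure_skip_past_key_values(model):
    """Set skip_keys on the model's top-level accelerate hook if nothing set it.

    Hypothetical helper relying on the private `_hf_hook` attribute discussed
    in this thread. Returns True if the model had a hook to inspect.
    """
    hook = getattr(model, "_hf_hook", None)
    if hook is None:
        return False  # not dispatched with a device_map, nothing to do
    if getattr(hook, "skip_keys", None) is None:
        hook.skip_keys = "past_key_values"
    return True


# Stand-in for a model whose hook ended up with skip_keys=None.
fake_model = SimpleNamespace(_hf_hook=SimpleNamespace(skip_keys=None))
ensure_skip_past_key_values(fake_model)
print(fake_model._hf_hook.skip_keys)  # past_key_values
```

With a real model, the same call would go right after `from_pretrained(..., device_map="auto")`.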

I can't reproduce. On main of Transformers and Accelerate,

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto", torch_dtype=torch.float16)
print(model._hf_hook.skip_keys)

gets me "past_key_values". What's your repro?

sgugger, Jun 01 '23 13:06

> gets me "past_key_values". What's your repro?

Yes, same here with this code.

Sorry for the confusion; for whatever reason, model._hf_hook.skip_keys is None when I use GPTQ-for-LLaMa.

Thanks for the quick fix!

g588928812, Jun 01 '23 18:06

What model architecture is GPTQ-for-LLaMa? If it's custom code, it needs to implement _skip_keys_device_placement like here to work out of the box :-)

sgugger, Jun 01 '23 20:06

So how might this be implemented for oobabooga's text-generation-webui? Explain it to me like I'm 5.

RedDragonGecko, Jun 10 '23 00:06

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jul 04 '23 15:07