accelerate
5x faster text generation on multi-GPU setups (+ lower VRAM consumption)
TL;DR: the patch below makes multi-GPU inference 5x faster.
I noticed that text generation is significantly slower on multi-GPU than on single-GPU setups. Some results (using llama models with the full 2048-token context window; I also tested with GPT-J and the results are similar):
| model size | quantization | GPU configuration | seconds/token | total peak VRAM usage |
|---|---|---|---|---|
| 7B | 16 | 1x4090 | 0.025 | 14996 MiB |
| 7B | 16 | 2x4090 | 0.22 (8x slower) | 16484 MiB |
| 7B w/patch | 16 | 2x4090 | 0.026 | 15891 MiB |
| 13B | 8 w/bnb | 1x4090 | 0.067 | 16076 MiB |
| 13B | 8 w/bnb | 2x4090 | 0.34 (5x slower) | 18087 MiB |
| 13B w/patch | 8 w/bnb | 2x4090 | 0.070 | 16883 MiB |
| 13B | 4 w/gptq | 1x4090 | 0.036 | 10219 MB |
| 13B | 4 w/gptq | 2x4090 | 0.228 (6x slower) | 11091 MB |
| 13B w/patch | 4 w/gptq | 2x4090 | 0.038 | 10336 MB |
This makes large models prohibitively slow: running a 65B model on 4x3090 takes more than 4 seconds/token, which is unusable for interactive applications.
After some profiling and debugging, I narrowed this down to a simple problem:
- accelerate wraps the `LlamaModel` and moves the HUGE `past_key_values` tensor to the `execution_device` of `LlamaModel` (which is GPU 0) at the beginning of the inference
- it also wraps each `LlamaDecoderLayer` and moves the same `past_key_values` (which stays constant during the entire inference pass) from GPU 0 to the execution device for each layer, repeatedly moving it between GPUs for every layer that is not on GPU 0.
This unnecessary repeated moving consumes up to 85% of the inference time. Furthermore, because it makes a copy of the past_key_values on each GPU, it significantly increases VRAM usage (although I didn't measure the exact number).
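To get a feel for how large `past_key_values` is at full context, here is a back-of-envelope sketch. The shapes are assumptions for LLaMA-7B (32 decoder layers, hidden size 4096), with fp16 values and a single sequence; the helper name `kv_cache_bytes` is made up for illustration.

```python
# Rough size of the key/value cache for one sequence.
# Assumed shapes (LLaMA-7B): 32 decoder layers, hidden size 4096, fp16 values.
def kv_cache_bytes(n_layers, hidden_size, seq_len, bytes_per_elem=2, batch=1):
    # Each layer keeps one key tensor and one value tensor; with the heads
    # flattened, each holds batch * seq_len * hidden_size elements.
    return 2 * n_layers * batch * seq_len * hidden_size * bytes_per_elem


size = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=2048)
print(f"{size / 2**20:.0f} MiB")  # 1024 MiB
```

So at 2048 tokens the cache is on the order of 1 GiB for the 7B model, and it grows linearly with layer count, hidden size, and context length.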
I'm not familiar enough with the accelerate code base to fix the root cause, but here's a simple patch that solves this problem. It keeps past_key_values sharded across GPUs, so it is never moved between GPUs (saving execution time) and its VRAM usage is split across GPUs (saving VRAM).
- save this as llama_accelerate_path.py:
```python
from collections.abc import Mapping

from accelerate.hooks import AlignDevicesHook, add_hook_to_module
from accelerate.utils import find_device, send_to_device


def send_to_device_except(data, device, non_blocking=False, skip_keys=()):
    # Move every entry of a mapping to `device`, except the keys in `skip_keys`.
    if isinstance(data, Mapping):
        return type(data)({
            k: v if k in skip_keys else send_to_device(v, device, non_blocking)
            for k, v in data.items()
        })
    else:
        return send_to_device(data, device, non_blocking)


class AlignLogitsHook(AlignDevicesHook):
    def pre_forward(self, module, *args, **kwargs):
        if self.io_same_device:
            self.input_device = find_device([args, kwargs])

        # Leave past_key_values on whichever GPU each piece already lives on.
        return (
            send_to_device(args, self.execution_device),
            send_to_device_except(kwargs, self.execution_device, skip_keys=("past_key_values",)),
        )

    def post_forward(self, module, output):
        if self.io_same_device and self.input_device is not None:
            output = send_to_device_except(output, self.input_device, skip_keys=("past_key_values",))
        return output


def apply_to_model(model):
    hook = AlignLogitsHook(execution_device=model._hf_hook.execution_device, io_same_device=True)
    add_hook_to_module(model, hook)
    return model
```
- Add this to the model loading code:
```python
# model loading with device_map="auto"
# model = AutoModelForCausalLM.from_pretrained(...., device_map="auto").eval()

# apply the patch (this works for llama models only):
import llama_accelerate_path
llama_accelerate_path.apply_to_model(model)  # attaches the hook to the model in place
```
NOTE: there is another redundancy: attention_mask and some other tensors are also copied from GPU 0 to GPU n for every layer's execution. It can be solved in a similar way, but it's less of a problem because attention_mask is much smaller than past_key_values.
It would be great if a more universal version of this fix could be merged to the accelerate mainline.
I implemented this patch and can confirm that it speeds up generation by a factor of 5. Memory usage seems to increase, though.
Before applying the patch I could run 1700 context length, now I OOM at 1500.
Great work, though!
Thanks a lot for the analysis and your fix suggestion. The idea of having a skip_keys argument in dispatch_model could definitely live in Accelerate and Transformers could then set it to the past_key_values when appropriate.
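The skip_keys idea can be sketched without any GPU. This is only an illustration of the mechanism, not Accelerate's implementation: `send_except`, `fake_move`, and the string "tensors" are hypothetical stand-ins, and the real hook moves tensors with `send_to_device`.

```python
from collections.abc import Mapping


def send_except(data, device, move, skip_keys=()):
    """Apply move(value, device) to every mapping entry except skip_keys."""
    if isinstance(skip_keys, str):
        skip_keys = (skip_keys,)
    if isinstance(data, Mapping):
        return type(data)(
            {k: v if k in skip_keys else move(v, device) for k, v in data.items()}
        )
    return move(data, device)


moved = []  # records which values were actually transferred


def fake_move(value, device):
    # Stand-in for a real device transfer; just tags the value with the device.
    moved.append(value)
    return f"{value}@{device}"


kwargs = {"input_ids": "ids", "attention_mask": "mask", "past_key_values": "cache"}
out = send_except(kwargs, "cuda:1", fake_move, skip_keys="past_key_values")
print(out["past_key_values"])  # cache
print(moved)                   # ['ids', 'mask']
```

The skipped entry is passed through untouched, so the cache stays on whatever device it was created on while the small inputs are still aligned with the execution device.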
I ran some more benchmarks and confirmed that the patch does reduce VRAM usage a little (see the updated table above), but it also changes how the VRAM usage is distributed across GPUs.
@Dhaladom: I was able to run 65b int4 on 2x4090 with up to 2019 tokens (oh that is truly killing the perfectionist in me...). I used the following device_map, and I also had to shut down the GUI on the machine (log out and run sudo service gdm3 stop on Ubuntu, or just use an internal or cheap third GPU for the GUI).
device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 0, 'model.layers.27': 0, 'model.layers.28': 0, 'model.layers.29': 0, 'model.layers.30': 0, 'model.layers.31': 0, 'model.layers.32': 0, 'model.layers.33': 0, 'model.layers.34': 0, 'model.layers.35': 0, 'model.layers.36': 0, 'model.layers.37': 0, 'model.layers.38': 0, 'model.layers.39': 0, 'model.layers.40': 1, 'model.layers.41': 1, 'model.layers.42': 1, 'model.layers.43': 1, 'model.layers.44': 1, 'model.layers.45': 1, 'model.layers.46': 1, 'model.layers.47': 1, 'model.layers.48': 1, 'model.layers.49': 1, 'model.layers.50': 1, 'model.layers.51': 1, 'model.layers.52': 1, 'model.layers.53': 1, 'model.layers.54': 1, 'model.layers.55': 1, 'model.layers.56': 1, 'model.layers.57': 1, 'model.layers.58': 1, 'model.layers.59': 1, 'model.layers.60': 1, 'model.layers.61': 1, 'model.layers.62': 1, 'model.layers.63': 1, 'model.layers.64': 1, 'model.layers.65': 1, 'model.layers.66': 1, 'model.layers.67': 1, 'model.layers.68': 1, 'model.layers.69': 1, 'model.layers.70': 1, 'model.layers.71': 1, 'model.layers.72': 1, 'model.layers.73': 1, 'model.layers.74': 1, 'model.layers.75': 1, 'model.layers.76': 1, 'model.layers.77': 1, 'model.layers.78': 1, 'model.layers.79': 1, 'model.norm': 1, 'lm_head': 1}
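The same map can be built programmatically, which makes it easier to move the split point when one card runs out of memory. A minimal sketch, assuming the 80-layer layout and module names shown above; the function name and its parameters are made up for illustration.

```python
def two_gpu_device_map(n_layers=80, split=40):
    """Put the embeddings and layers [0, split) on GPU 0, everything else on GPU 1."""
    device_map = {"model.embed_tokens": 0}
    for i in range(n_layers):
        device_map[f"model.layers.{i}"] = 0 if i < split else 1
    device_map["model.norm"] = 1
    device_map["lm_head"] = 1
    return device_map


device_map = two_gpu_device_map()
print(len(device_map))  # 83
```

Lowering `split` shifts layers from GPU 0 to GPU 1, which may help if the first card is the one that goes OOM.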
@emvw7yf: If there was no perfectionism, I would already generate useful output with those models instead of optimizing further. 😄
The device map is surprising as on my system, the first card ("cuda:0") shows no VRAM activity at all after its initial loading. Accordingly, I mapped as many tensors as possible on this one and then the second one goes OOM.
The discrepancies may be related to triton. I don't understand what's really going on there yet, but I'm running auto-tune before inference, and the auto-tune algorithm only benchmarks the first GPU, so that may be the cause.
Other than that, I'm using a Linux headless server already.
Anyway, thanks for taking the trouble and testing it out.
i've noticed this performance drop too on my 2x 3090 setup. thank you for the fix! will this patch be implemented in the main branch eventually?
another question, maybe stupid: even with the patch, it seems, you get the same speed in terms of token generation for 2xGPU as for 1xGPU (table above). Why is this not twice as fast? The two GPUs should be calculating in parallel, right?
And another one: why does quantization not make inference faster? intuitively, less bits per parameter, faster calculation.
thanks for your help!
> And another one: why does quantization not make inference faster? intuitively, less bits per parameter, faster calculation.
I can't speak from experience of this PR specifically yet (I just learned about it and plan to try it soon), and I'm not yet very experienced with dual-GPU inference. But in general I can tell you that, based on my testing of GPTQ 4-bit models, a quantised model definitely does perform much better than a float16 model, unless you are CPU bottlenecked.
Recently I did a comprehensive set of GPTQ performance benchmarks using AutoGPTQ, which I posted in the AutoGPTQ repo. I was testing with a 4090 on a Runpod server which had an AMD Epyc 24-core CPU - a CPU with high core count, but low single-core performance. Not that I really thought about the CPU at first.
I got a peak of 28 tokens/s with 4bit GPTQ and that seemed reasonable from what I'd come to expect in prior testing, and from other people in the community. Testing with HF float16 I got 23 tokens/s. Not as big a performance difference as I expected, but 4bit was still faster. I did notice that I never saw more than 30-45% GPU usage.
Then someone asked me why my results were so slow.
Long story short: it turned out that all my benchmarks were CPU bottlenecked. When you do inference with transformers/pytorch, only a single CPU core is used. And it seems that with the faster GPUs, the CPU can absolutely become a bottleneck. Hence the low GPU usage %.
I searched around for a better CPU, and found a Runpod container with an i9-13900K, which is the current leader on single-core benchmarks.
With that, still using a 4090, I got 50 tokens/s on float16 and 98 tokens/s on 4bit GPTQ! You can see those results (and my amazement), here: https://github.com/PanQiWei/AutoGPTQ/issues/49#issuecomment-1539094460
So TLDR: if you're getting the same performance from fp16 and 4bit, check to see if you're pegged at 100% of one core, and are well below 100% GPU usage.
And if anyone knows any tricks or methods to avoid this single-core bottleneck, I would love to know. I've since experimented with transformers' pipeline using batch_size greater than 1, and this does enable using the full GPU, even with a weak CPU. However, that doesn't help in single-prompt scenarios, and it brings its own complexities (e.g. when the prompts in a batch all have different lengths).
> i've noticed this performance drop too on my 2x 3090 setup. thank you for the fix! will this patch be implemented in the main branch eventually?
> another question, maybe stupid: even with the patch, it seems, you get the same speed in terms of token generation for 2xGPU as for 1xGPU (table above). Why is this not twice as fast? The two GPUs should be calculating in parallel, right?
The "interactive chat" kind of work-load is almost entirely memory-bound (less so on a 3090 than a 4090 but still). The compute hardware mostly just waits for weights to load into the SM caches. So more GPUs is actually worse for performance. You have more FLOPs sure, but that's just more idle hardware. And now you added the time to transfer state between the two GPUs over an external bus.
(On the other hand if you can batch inference work or run with longer sequences you get to reuse the weights for multiple inputs.)
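A rough way to see the memory-bound argument: if every weight has to be read once per generated token, memory bandwidth sets a floor on seconds/token. The sketch below assumes about 7e9 parameters and roughly 1008 GB/s of memory bandwidth (the published figure for a 4090), and it ignores the KV cache, activations, and kernel overhead, so real numbers will be slower.

```python
def min_seconds_per_token(n_params, bytes_per_param, bandwidth_bytes_per_s):
    # Lower bound: every weight is streamed from VRAM once per generated token.
    return n_params * bytes_per_param / bandwidth_bytes_per_s


BANDWIDTH = 1008e9  # bytes/s, assumed 4090 memory bandwidth

fp16 = min_seconds_per_token(7e9, 2, BANDWIDTH)
int4 = min_seconds_per_token(7e9, 0.5, BANDWIDTH)
print(f"fp16: {fp16 * 1000:.1f} ms/token, at most {1 / fp16:.0f} tokens/s")
print(f"int4: {int4 * 1000:.1f} ms/token")
```

Under these assumptions the fp16 floor is about 14 ms/token, the same order of magnitude as the 0.025 seconds/token measured for 7B on 1x4090 in the table above. Splitting the layers across two GPUs does not lower this floor for a single sequence, because the layers still run one after another, while fewer bytes per parameter does lower it.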
> And another one: why does quantization not make inference faster? intuitively, less bits per parameter, faster calculation.
It probably can once it's optimised all the way. To some extent, quantization exactly counters the problem above. You have to spend some compute unpacking the values, but compute is something you have plenty of in the single-token scenario. I don't have numbers directly comparable to the OP's measurements, but with this change I got over 112 tokens/s on 4bit 7b w/ 3 batches on 1x4090.
What is the status of this PR? Is it ready to merge? Are there any concerns that need to be resolved?
> Long story short: it turned out that all my benchmarks were CPU bottlenecked. When you do inference with transformers/pytorch, only a single CPU core is used. And it seems that with the faster GPUs, the CPU can absolutely become a bottleneck. Hence the low GPU usage %.

it's impressive what you did. as impressive is that in these times torch only uses a single core for inference. have you, by any chance, found a way to fix this? i don't want to buy a new CPU (and motherboard) 😆
@emvw7yf I started to draft something in the PR linked above. I am not seeing your speed-ups on two GPUs, but I have an nvlink so it might be why the slowdown is not as drastic. Could you try on the branch of the PR and see if adding
```python
model._hf_hook.skip_keys = "past_key_values"
```
gives the same speed-ups as your patches? It should do the same thing.
When it's confirmed this works, we can get more support directly baked in Transformers.
> gives the same speed-ups as your patches? It should do the same thing.
Works! ~5x faster (14t/s vs 2.8t/s in my own benchmark on 2x RTX 3090 and OA 30B model)
Ok, support should be coming natively in Transformers with the PR above then.
With latest Accelerate as well? You can check if model._hf_hook.skip_keys errors or shows "past_key_values".
Sorry, deleted my initial post because I saw that you merged a PR in accelerate too -> updated accelerate -> same. Still slow.
model._hf_hook.skip_keys is None
setting it manually fixes it, with the line below: 5x speed
```python
model._hf_hook.skip_keys = "past_key_values"
```
I can't reproduce. On main of Transformers and Accelerate,
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto", torch_dtype=torch.float16)
print(model._hf_hook.skip_keys)
```
gets me "past_key_values". What's your repro?
> gets me "past_key_values". What's your repro?
yes, same here with this code.
sorry for the confusion, for whatever reason model._hf_hook.skip_keys is None when I use GPTQ-for-LLaMa
thanks for the quick fix!
What model architecture is GPTQ-for-LLaMa? If it's custom code, it needs to implement _skip_keys_device_placement like here to work out of the box :-)
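For custom model code where the hook ends up with skip_keys set to None, the manual workaround from earlier in the thread can be wrapped in a small guard. This is a sketch only: `ensure_skip_keys` is a hypothetical helper, and the `SimpleNamespace` objects stand in for a dispatched model so the snippet runs without a GPU.

```python
from types import SimpleNamespace


def ensure_skip_keys(model, key="past_key_values"):
    """Set hook.skip_keys if the model was dispatched but no skip key is set."""
    hook = getattr(model, "_hf_hook", None)
    if hook is None:
        return False  # no Accelerate hook attached, nothing to do
    if getattr(hook, "skip_keys", None) is None:
        hook.skip_keys = key
        return True
    return False  # already configured, leave it alone


# Stand-in for a model loaded with device_map="auto" by custom code.
dummy = SimpleNamespace(_hf_hook=SimpleNamespace(skip_keys=None))
changed = ensure_skip_keys(dummy)
print(changed, dummy._hf_hook.skip_keys)  # True past_key_values
```

On a real model this would be called once after loading and before generation.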
So how might this be implemented for oobabooga's text generation web UI? Explain to me like I'm 5.