transformers
Performance Regression from commit 7dcd870
System Info
- transformers version: 4.28.0.dev0 (656e869a4523f6a0ce90b3aacbb05cc8fb5794bb)
- Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.35
- Python version: 3.10.10
- Huggingface_hub version: 0.13.4
- Safetensors version: 0.3.0
- PyTorch version (GPU?): 2.0.0 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: False
Who can help?
@ArthurZucker @younesbelkada
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
I have a benchmark script that measures the generation speed of different LLaMA models. Before commit 7dcd870, generation averaged around 48 tokens/s in ideal cases on an RTX 3090. After that commit, the average is 43 tokens/s.
The culprit seems to be the change to apply_rotary_pos_emb. My guess is that switching from a rather simple slicing of two tensors to a gather is what costs the time.
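To make the slicing-versus-gather distinction concrete, here is a minimal sketch. NumPy stands in for PyTorch so it runs without a GPU, and the table sizes and the helper names slice_positions and gather_positions are hypothetical. It assumes the pre-7dcd870 code took one contiguous window of the cos/sin cache shared by the whole batch, while the new code looks up positions per row. When every row uses the same consecutive positions, both give identical tables, so the extra work of the gather only buys something when rows need different positions.

```python
import numpy as np

# Hypothetical stand-in for the cached cos table: [max_positions, dim].
rng = np.random.default_rng(0)
max_positions, dim = 16, 4
cos_table = rng.standard_normal((max_positions, dim))


def slice_positions(table, offset, seq_len):
    # Slicing style (assumed pre-7dcd870 behaviour): one contiguous window
    # of the table, shared by every row in the batch.
    return table[offset:offset + seq_len]  # [seq_len, dim]


def gather_positions(table, position_ids):
    # Gather style: every row looks up its own positions.
    return table[position_ids]  # [bs, seq_len, dim]


# Without padding, every row sits at positions offset, offset + 1, ...
batch_size, offset, seq_len = 3, 5, 4
position_ids = np.tile(np.arange(offset, offset + seq_len), (batch_size, 1))

sliced = slice_positions(cos_table, offset, seq_len)
gathered = gather_positions(cos_table, position_ids)

# Each gathered row equals the single shared slice.
print(all(np.array_equal(row, sliced) for row in gathered))  # True
```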
To test my theory, I reverted apply_rotary_pos_emb to its pre-7dcd870 state and made the minimal matching changes to LlamaAttention, with no other modifications. Speed jumped back to 48 tokens/s.
The problem should apply generally, but the specific script I'm using is: https://github.com/fpgaminer/GPTQ-triton/blob/99ec4a3adb7fad9de33ff026bbfb64cbb3bab2f8/benchmark_generate.py
Expected behavior
I would not expect a 10% drop in performance.
cc @gante and @ArthurZucker
@fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)
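Why padding forces the gather: with left padding, each row's real tokens start at a different offset, so rows need different position indices. The sketch below derives position_ids from an attention mask with a cumulative sum, which is how I understand LLaMA's prepare_inputs_for_generation does it in this version; treat that as an assumption. NumPy stands in for PyTorch, and the mask and the toy cos table are made-up examples.

```python
import numpy as np

# Hypothetical batch of two sequences; row 0 is left-padded by two tokens.
attention_mask = np.array([[0, 0, 1, 1, 1],
                           [1, 1, 1, 1, 1]])

# Count real tokens seen so far, so the first real token gets position 0.
position_ids = np.cumsum(attention_mask, axis=-1) - 1
# Padding slots get a dummy position (1); they are masked out of attention anyway.
position_ids = np.where(attention_mask == 0, 1, position_ids)
print(position_ids)
# [[1 1 0 1 2]
#  [0 1 2 3 4]]

# Each row now needs a different window of the cos/sin cache, so one shared
# slice no longer works and the table has to be indexed per row.
cos_table = np.cos(np.arange(8.0))[:, None]  # toy [max_positions, 1] table
per_row_cos = cos_table[position_ids]  # [bs, seq_len, 1]
```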
We optimize for correctness rather than performance. To skip this gathering while remaining correct, .generate() would need to be rewritten to dynamically squeeze out padding and evict completed rows, which is in our plans for the coming months.
Meanwhile, is there anything else we can help you with?
That's fair, though a 10% performance hit is rather painful.
To that end, here's my attempt to optimize apply_rotary_pos_emb:
```python
import torch
import triton.testing


def rotate_half(x):
    # Same helper as in modeling_llama: rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Current version: repeat the [1, 1, seq_len, dim] tables across the batch, then gather.
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Proposed version: drop the two leading singleton dims and index the table directly.
    # squeeze with a tuple of dims requires PyTorch >= 2.0.
    cos = cos.squeeze((0, 1))  # [seq_len, dim]
    sin = sin.squeeze((0, 1))  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def test_foo(B, L):
    cos = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
    sin = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
    position_ids = torch.randint(0, 2048, (B, L), dtype=torch.int64, device='cuda')
    q = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')
    k = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')

    # Verify: both versions must produce bit-identical outputs
    ref = ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    fast = fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert torch.equal(ref[0], fast[0])
    assert torch.equal(ref[1], fast[1])

    # Benchmark: with Triton 2.0, do_bench returns the median, 20th and 80th percentile times in ms
    ref_ms, ref_p20_ms, ref_p80_ms = triton.testing.do_bench(lambda: ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids))
    fast_ms, fast_p20_ms, fast_p80_ms = triton.testing.do_bench(lambda: fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids))
    speedup = ref_ms * 100 / fast_ms
    print(f'{B} | {L:3d} | {ref_ms:.6f} | {fast_ms:.6f} | {speedup:.2f}%')


print('B | L | ref | fast | speedup')
for B in [1, 2, 4, 8]:
    for L in [1, 2, 4, 8, 10, 100]:
        test_foo(B, L)
```
Output (ref and fast are median run times in ms; speedup is ref / fast as a percentage):
```
B | L | ref | fast | speedup
1 |   1 | 0.043008 | 0.035840 | 120.00%
1 |   2 | 0.044032 | 0.036864 | 119.44%
1 |   4 | 0.047104 | 0.038912 | 121.05%
1 |   8 | 0.046080 | 0.039936 | 115.38%
1 |  10 | 0.048128 | 0.039936 | 120.51%
1 | 100 | 0.058368 | 0.052224 | 111.76%
2 |   1 | 0.047104 | 0.036864 | 127.78%
2 |   2 | 0.049152 | 0.039936 | 123.08%
2 |   4 | 0.050176 | 0.040960 | 122.50%
2 |   8 | 0.050176 | 0.041984 | 119.51%
2 |  10 | 0.050176 | 0.041984 | 119.51%
2 | 100 | 0.079872 | 0.070656 | 113.04%
4 |   1 | 0.051200 | 0.039936 | 128.21%
4 |   2 | 0.053248 | 0.040960 | 130.00%
4 |   4 | 0.054272 | 0.041984 | 129.27%
4 |   8 | 0.057344 | 0.045056 | 127.27%
4 |  10 | 0.057344 | 0.045056 | 127.27%
4 | 100 | 0.130048 | 0.119808 | 108.55%
8 |   1 | 0.057344 | 0.040960 | 140.00%
8 |   2 | 0.059392 | 0.041984 | 141.46%
8 |   4 | 0.062464 | 0.045056 | 138.64%
```
For reference, the pre-7dcd870 function runs in 0.030 ms for B=1, L=1, so this version isn't quite as fast, but it gets closer.
Would a pull request with this change be welcome? I've done my best to verify its correctness with the above code.
@fpgaminer that is great! Absolutely, a PR would be very welcome 🙌
(We'd be happy to integrate other optimizations if you spot them; we rarely have the bandwidth to optimize our modeling code.)
> @fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)
Maybe there's something I'm not seeing here, but Llama uses rotary positional embeddings, so left padding should have no effect on the result?
Sure, the intermediate result from apply_rotary_pos_emb changes if you shift all tokens left or right, but the whole point of using relative embeddings is that the final attention weights are invariant to absolute position. So you can shift all tokens 50 positions to the right and the attention score between any pair of tokens will be the same, modulo rounding errors.
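A quick numerical check of this argument, as a NumPy sketch: it applies rotary embeddings with the rotate_half layout from the LLaMA code quoted in this thread (base 10000 and head dim 128 are assumptions) and compares raw attention scores for positions 0..7 against the same tokens shifted 50 positions to the right. The rotated queries change, but the pairwise scores agree up to float64 rounding, because rotating q by angle θ_m and k by θ_n leaves their dot product dependent only on θ_n − θ_m.

```python
import numpy as np


def rotate_half(x):
    # LLaMA-style layout: (x1, x2) -> (-x2, x1) over the two halves of the last dim.
    half = x.shape[-1] // 2
    return np.concatenate((-x[..., half:], x[..., :half]), axis=-1)


def apply_rope(x, positions, base=10000.0):
    # x: [seq_len, dim]; rotates each (x_i, x_{i + dim/2}) pair by position * inv_freq_i.
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    emb = np.concatenate([np.outer(positions, inv_freq)] * 2, axis=-1)  # [seq_len, dim]
    return x * np.cos(emb) + rotate_half(x) * np.sin(emb)


def attention_scores(q, k, positions):
    # Raw (unscaled, unmasked) dot-product scores between all token pairs.
    return apply_rope(q, positions) @ apply_rope(k, positions).T


rng = np.random.default_rng(0)
seq_len, dim = 8, 128
q = rng.standard_normal((seq_len, dim))
k = rng.standard_normal((seq_len, dim))
positions = np.arange(seq_len)

scores = attention_scores(q, k, positions)
shifted = attention_scores(q, k, positions + 50)  # same tokens, 50 positions to the right

# The rotated queries differ, but the pairwise scores agree up to float64 rounding.
print(np.allclose(apply_rope(q, positions), apply_rope(q, positions + 50)))  # False
print(np.allclose(scores, shifted))  # True
```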
Or are you saying there are cases where padding is literally inserted inside the sequence, thereby changing the relative distances between tokens, @gante?
@aljungberg I agree with everything you wrote: rotary positional embeddings should be position-invariant. In practice, the small rounding errors compound over autoregressive text generation, leading greedy decoding (which is normally invariant with respect to small fluctuations) to produce different text.
With the right position index, the error becomes much smaller, and the results become more stable regardless of padding. That's why we also added it to our high-performance text generation repo, despite the difference being quite small.
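A NumPy sketch of this effect: run the same eight real tokens once unpadded and once behind 50 left-padding tokens, with the rotary tables and the rotation done in float16. If the padded run counts padding when assigning positions (50..57), its scores drift slightly from the unpadded run; if positions come from the attention mask (0..7 for the real tokens), the computation is identical and the scores match exactly. The sizes, base 10000, and casting the cos/sin tables to float16 are illustrative assumptions, not a reproduction of the transformers test.

```python
import numpy as np


def rotate_half(x):
    # LLaMA-style layout: (x1, x2) -> (-x2, x1) over the two halves of the last dim.
    half = x.shape[-1] // 2
    return np.concatenate((-x[..., half:], x[..., :half]), axis=-1)


def scores_fp16(q, k, positions, base=10000.0):
    # Raw attention scores with rotary embeddings, computed in float16.
    dim = q.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    emb = np.concatenate([np.outer(positions, inv_freq)] * 2, axis=-1)
    cos = np.cos(emb).astype(np.float16)
    sin = np.sin(emb).astype(np.float16)
    q16, k16 = q.astype(np.float16), k.astype(np.float16)
    q_rot = q16 * cos + rotate_half(q16) * sin
    k_rot = k16 * cos + rotate_half(k16) * sin
    return (q_rot @ k_rot.T).astype(np.float64)


rng = np.random.default_rng(0)
n_real, n_pad, dim = 8, 50, 128
q = rng.standard_normal((n_real, dim))
k = rng.standard_normal((n_real, dim))

# Reference: the same tokens with no padding at all.
unpadded = scores_fp16(q, k, np.arange(n_real))

# Left-padded row: the mask marks the first n_pad slots as padding.
mask = np.concatenate([np.zeros(n_pad, dtype=int), np.ones(n_real, dtype=int)])
naive_pos = np.arange(n_pad + n_real)[mask == 1]  # counts padding: 50..57
mask_pos = (np.cumsum(mask) - 1)[mask == 1]  # derived from the mask: 0..7

naive_drift = np.abs(scores_fp16(q, k, naive_pos) - unpadded).max()
mask_drift = np.abs(scores_fp16(q, k, mask_pos) - unpadded).max()
print("counting padding:", naive_drift)
print("positions from mask:", mask_drift)  # 0.0
```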
Out of curiosity, this test was failing on GPTNeoX and Llama before we added this change. In theory, it shouldn't have failed at all!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.