xformers
Significant performance drops when using fast memory efficient attention
🐛 Bug
I am currently experimenting with different scaled dot product attention implementations to evaluate training speed and GPU memory consumption.
I compared all methods by running the following train.py from Lucidrains' x-transformers library: https://github.com/lucidrains/x-transformers/blob/main/examples/enwik8_simple/train.py.
In order to compare the methods I altered the attention implementation in https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py. Concretely, I commented out lines 758 - 822 and dropped in the following implementations. The causal mask below is built identically for all methods.
# building the causal mask -> precedes all of the following implementations
# b (batch size), h (number of heads), i (query len), j (key len)
attn_bias = torch.ones((b, h, i, j), dtype=torch.bool, device=device).triu(j - i + 1)
attn_bias = torch.zeros_like(attn_bias, dtype=q.dtype).masked_fill(attn_bias, float("-inf"))
-
xformers memory-efficient implementation:
import xformers.ops as xops
# transposing is needed because xformers expects the shape order (B, N, H, D)
out = xops.memory_efficient_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_bias=attn_bias)
out = out.transpose(1, 2)
-
pytorch 2.0 implementation in math mode (not memory efficient):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
-
pytorch 2.0 implementation in memory-efficient mode (this seems to use the xformers implementation):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
-
lucidrains' pytorch implementation of memory-efficient attention: https://github.com/lucidrains/memory-efficient-attention-pytorch
from memory_efficient_attention_pytorch import memory_efficient_attention
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, q_bucket_size=1024, k_bucket_size=2048)
The configured constants in the above linked training script are:
NUM_BATCHES = 200 # int(1e5)
BATCH_SIZE = 2
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 2048
and the model is always initialized as:
model = TransformerWrapper(num_tokens=256, max_seq_len=SEQ_LEN, attn_layers=Decoder(dim=1024, depth=6, heads=8))
Otherwise the linked training and model scripts are unchanged.
I started training runs for all 4 configurations (using a V100 32GB GPU, see below for detailed environment info), with and without passing the actual attn_bias to the attention function.
These are the performance (speed and memory consumption) results:
Method | GPU Memory / tqdm time per batch (no attn_bias) | GPU Memory / tqdm time per batch (attn_bias)
---|---|---
xformer | 3518MiB, 1.31s/it | 6662MiB, 1.37s/it
pt math | 5638MiB, 0.71s/it | 5638MiB, 0.76s/it
pt mem | 3518MiB, 1.31s/it | error (https://github.com/pytorch/pytorch/issues/97514)
lucid | 3656MiB, 0.84s/it | 5308MiB, 0.91s/it
As you can see, the math-mode pytorch 2.0 function is almost twice as fast as the xformers implementation, and when providing a full attn_bias its memory footprint is even smaller.
The "old" memory-efficient implementation by lucidrains is considerably faster in my experiment, which does not really make sense, since the xformers implementation should be optimized for CUDA, no?
Is the above expected behaviour, or what is the reason for these results?
Environment
PyTorch version: 2.0.0+cu117 Is debug build: False CUDA used to build PyTorch: 11.7 ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.2 LTS (x86_64) GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0 Clang version: Could not collect CMake version: version 3.26.0 Libc version: glibc-2.27
Python version: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0] (64-bit runtime) Python platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.27 Is CUDA available: True CUDA runtime version: Could not collect CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB GPU 1: Tesla V100-SXM2-32GB
Nvidia driver version: 470.86 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True
CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian CPU(s): 80 On-line CPU(s) list: 0-79 Thread(s) per core: 2 Core(s) per socket: 20 Socket(s): 2 NUMA node(s): 2 Vendor ID: GenuineIntel CPU family: 6 Model: 85 Model name: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz Stepping: 4 CPU MHz: 1891.113 CPU max MHz: 3700.0000 CPU min MHz: 1000.0000 BogoMIPS: 4800.00 Virtualization: VT-x L1d cache: 32K L1i cache: 32K L2 cache: 1024K L3 cache: 28160K NUMA node0 CPU(s): 0-19,40-59 NUMA node1 CPU(s): 20-39,60-79 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_pt ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear spec_ctrl intel_stibp flush_l1d
Versions of relevant libraries: [pip3] memory-efficient-attention-pytorch==0.1.2 [pip3] mypy-extensions==1.0.0 [pip3] numpy==1.24.2 [pip3] torch==2.0.0 [conda] memory-efficient-attention-pytorch 0.1.2 pypi_0 pypi [conda] numpy 1.24.2 pypi_0 pypi [conda] torch 2.0.0 pypi_0 pypi
Hi @LarsHill Thanks for the detailed report :) Indeed, xformers' memory-efficient attention should be faster than the lucid one. A few questions:
(1) I assume you are measuring iteration time for training, right? (so including the backward pass)
(2) What are the shapes of your problem? (batch size, number of keys/queries, embedding size, number of heads)
(3) Are you training in half-precision?
And some things you could improve:
(a) Please don't create the causal mask tensor yourself. You should use this one instead - this should already be much faster:
# Signals xFormers that we want causal attention, but without creating any torch.Tensor
attn_bias = xformers.ops.LowerTriangularMask()
(b) You are doing a lot of transposes which might not be needed. I see multiple rearrange calls like this one which require memory copy during forward and backward:
rearrange(q, 'b n (h d) -> b h n d', h = h)
xFormers takes inputs directly in the BNHD format (what we call BMHK in our codebase), so this should be faster.
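For illustration, here is a minimal hedged sketch (not code from this issue) of calling memory_efficient_attention with q, k, v kept in the [batch, seq_len, heads, head_dim] layout, so no transposes or rearranges are needed; the shapes below are placeholders:
import torch
import xformers.ops as xops

B, M, H, K = 2, 2048, 8, 64  # batch, sequence length, heads, head dim (illustrative)
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# LowerTriangularMask requests causal attention without materialising a mask tensor
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
# out keeps the [B, M, H, K] layout, so no transpose back is needed either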
Hi @danthe3rd Thanks for the very quick response!
- Indeed, it is a training iteration including the backward pass, optimizer step, etc. (see the linked train.py script with the training loop). In a separate script with some random q, k, v data I only benchmarked a simple forward pass through the attention functions. There the xformers method was significantly faster. So I would suspect the backward pass has something to do with it, and the attention bias as well.
- Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].
- I train with torch.float32 in all cases.
(a) Using attn_bias = xformers.ops.LowerTriangularMask() speeds it up to 0.93it/s and reduces memory to 3518MiB. The speed-up is especially surprising, since it is now faster than passing no mask at all... However, practically this won't be that useful, since I need to apply specific bias tensors, e.g. the alibi positional bias, in my actual training runs. Also, if pooling is necessary I need to alter the bias by adding -inf for all padded tokens (see the sketch below, after this comment). So, ideally I would like to pass a complete custom mask to not lose flexibility. The few performant mask implementations like xformers.ops.LowerTriangularMask() unfortunately don't cover all the relevant cases...
(b) True, but in that script the input q, k, v dimension is [b, h, n, d], so in order to keep it aligned I had to rearrange the shape. However, even if I remove the transpositions I get the same performance results, so the impact seems to be negligible.
Btw. I'm using the pre-release of xformers (installed via pip) that is compatible with pytorch 2.0.
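Following up on (a): a hedged sketch (not code from this thread) of the kind of dense float bias described there, combining causal masking with -inf on padded key positions; the key_padding_mask argument is a hypothetical [b, j] bool input marking padded key tokens:
import torch

def build_dense_bias(b, h, i, j, key_padding_mask, device, dtype=torch.float32):
    # key_padding_mask: [b, j] bool tensor, True where the key token is padding (assumed input)
    causal = torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
    bias = torch.zeros((b, h, i, j), dtype=dtype, device=device)
    bias = bias.masked_fill(causal, float("-inf"))                               # causal part
    bias = bias.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))   # key-padding part
    return bias
Being a plain torch.Tensor bias, this does not benefit from the optimized causal-mask fast path discussed later in the thread.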
I train with torch.float32 in all cases.
xformers kernels have been specially optimized for f16 or bf16 (A100). If you can run your model with either autocast or fully in f16, you will get much better performance. Flash-Attention makes the tradeoff of recalculating more stuff to save memory transfers. This works well because compute has become super fast, while memory is slower in comparison. On older GPUs (V100), the compute is still quite expensive (especially in f32), so that might be the reason why you don't see speedups in f32 (PT math will store the attention matrix for the BW pass and not recompute it).
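A minimal hedged sketch of the kind of autocast + gradient-scaling setup suggested here; the model, optimizer, data, and loss below are placeholders, not the thread's train.py:
import torch

model = torch.nn.Linear(1024, 1024).cuda()                 # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(2, 1024, device="cuda")                    # placeholder data

for _ in range(10):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).float().pow(2).mean()              # placeholder loss
    scaler.scale(loss).backward()                          # scale the loss to avoid f16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
Inside the autocast region, matmul-heavy ops such as the attention run in f16, which is where the optimized kernels are fastest.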
@danthe3rd Thanks for the added information.
I added autocast and gradient scaling to the training and validation loop and tested the performance again. Unfortunately the overall picture is still the same: the xformers scaled dot product attention is still the slowest contender among all methods...
It would be sad if the cause really is the "older" V100 GPU. I mean, it is not that old after all and is still listed among the top-tier ML GPUs.
I turned the q, k, v format around so that all other methods need transposition but the xformers method does not (advantage xformers). Also, I tested with different masking setups and with autocast (f16) and grad scaling enabled. Here are the results:
Method | no mask | custom float tensor mask | xformer optimized causal mask
---|---|---|---
xformer | 3126MiB, 2.24it/s | 4554MiB, 2.01it/s | 3150MiB, 3.24it/s
pt math | 4144MiB, 3.81it/s | 4144MiB, 3.63it/s | x
lucid | 3486MiB, 2.19it/s | 4126MiB, 2.06it/s | x
First of all, I think it is odd that passing an optimized xformers.ops.LowerTriangularMask() performs better than passing no mask at all. Passing and applying no bias at all should intuitively be faster; the opposite is the case for xformers.
Second, with fp16 (autocast) the difference to lucidrains' implementation is less significant. Still, when passing a custom float mask, lucidrains' implementation performs slightly better.
Interestingly, when providing a custom float tensor mask to the xformers dot product, both the GPU memory consumption and the speed are worse than with the standard pytorch implementation.
Overall, given that one cannot use an optimized xformers mask like xformers.ops.LowerTriangularMask() due to specific masking needs (e.g. key padding, alibi relative positional bias, etc.), one is significantly better off with the standard pytorch math implementation. The memory does not increase when providing such a mask, and the speed is significantly higher.
These are my current conclusions. It would be interesting to see if the results drastically change with an A100 GPU. Unfortunately I don't have access to one.
Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].
Currently, the backward on V100 isn't well parallelised. You will get best performance if b * h > 100. So if possible, I would increase the batch size.
First of all, I think it is odd that passing an optimized xformers.ops.LowerTriangularMask() performs better than passing no mask at all
It is faster because the kernel knows it can skip half of the calculations (what is masked out). If you pass a torch.Tensor as bias, it has to compute everything and then add the bias.
Overall, given that one cannot use an optimized xformers mask like xformers.ops.LowerTriangularMask() due to specific masking needs, e.g. key padding, alibi relative positional bias, etc.
We support a few of these optimized masks. If you want to combine a mask with causal masking, you can use LowerTriangularMaskWithTensorBias. If you have sequences of various lengths, you can use the BlockDiagonalMask.
We don't have anything for Alibi at the moment unfortunately...
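To make the two suggestions above concrete, here is a hedged sketch (assumed usage, not code from this thread) of constructing LowerTriangularMaskWithTensorBias and BlockDiagonalMask; the bias contents and sequence lengths are placeholders:
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMaskWithTensorBias

b, h, n, d = 2, 8, 2048, 64
device, dtype = "cuda", torch.float16
q = torch.randn(b, n, h, d, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

# 1) Causal masking combined with an additive tensor bias (e.g. an ALiBi bias);
#    the zeros here are just a placeholder for the real bias values.
tensor_bias = torch.zeros(b, h, n, n, device=device, dtype=dtype)
out = xops.memory_efficient_attention(
    q, k, v, attn_bias=LowerTriangularMaskWithTensorBias(tensor_bias)
)

# 2) Sequences of various lengths packed along the sequence axis (batch size 1).
seqlens = [512, 1024, 512]
packed_q = torch.randn(1, sum(seqlens), h, d, device=device, dtype=dtype)
packed_k, packed_v = torch.randn_like(packed_q), torch.randn_like(packed_q)
out_packed = xops.memory_efficient_attention(
    packed_q, packed_k, packed_v, attn_bias=BlockDiagonalMask.from_seqlens(seqlens)
)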
Just curious why supporting Alibi is difficult? I noticed that the official flash attention repo doesn't support it either.
Just curious why supporting Alibi is difficult? I noticed that the official flash attention repo doesn't support it either.
I don't think it's difficult. It's just some additional work required to make it run fast, it adds something more we need to support, and we also need to add support for it in the BW pass.
@danthe3rd I also need alibi support. For now, I pass bias = LowerTriangularMaskWithTensorBias(alibi_bias) to xops.memory_efficient_attention(..., attn_bias=bias). The forward pass is OK, but it fails at backward in training mode. Is this expected behaviour for now?
If so, is there any plan to make its performance (speed & memory) as good as flash_attention or xformers without an attention bias?
We don't plan to implement it ourselves at the moment. However, it seems to be on @tridao 's roadmap
[May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
Once he implements it, we will make it work in xFormers (as we can use Flash-Attention under the hood)
I met the same issue. I wanted to left-pad my sequences, and thus I used LowerTriangularMaskWithTensorBias --> it failed at backward. Is there any way to use this flexible & optimized mask for training?
However, it seems to be on @tridao 's roadmap
It looks like it's no longer on the roadmap.
On our side, we don't plan to implement that on the xFormers team, as researchers mostly use Rope embeddings rather than Alibi here. I know PyTorch was considering adding support for this, but I'm not sure what they decided, and whether or not this will include the backward pass (cc @drisspg )
Is there any way for using this flexible & optimized mask for training?
I assume that your bias is learnable right?
@hiyijian Hi, have you seen any improvement in inference performance when using LowerTriangularMaskWithTensorBias(alibi_bias)? In my tests, the speed has actually decreased. Here are the details of my environment:
- GPU: A100-80G
- xformers: 0.0.21.dev574
- torch: 2.0.1
- triton: 2.0.0
@hiyijian Thanks for sharing your approach! Have you used a KV cache? I found that the result was wrong after enabling the KV cache.
I would also like to know how to efficiently use a custom mask rather than one of the predefined masks. I am trying to update the mask during training, where the mask is not learnable. However, I notice that the mask still has a gradient after being passed to xops.memory_efficient_attention, and the GPU memory increases a lot. I guess it is because the mask is a plain tensor, not an xformers.ops.LowerTriangularMask()-like attn_bias? Any idea how to solve it?