flash-cosine-sim-attention
GPU Benchmarks
Hi Phil,
Firstly, thank you for the amazing work yet again!
I was wondering if you had done any benchmarking with mid-tier GPUs. I ran the benchmarks on my local system with a few RTX 3090s and received these results:
python3 benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.96x kernel: 0.23ms baseline: 0.24ms
seq_len: 256 slower: 1.32x kernel: 0.38ms baseline: 0.28ms
seq_len: 512 slower: 1.85x kernel: 0.82ms baseline: 0.44ms
seq_len: 1024 slower: 1.57x kernel: 2.15ms baseline: 1.37ms
seq_len: 2048 slower: 1.17x kernel: 5.94ms baseline: 5.06ms
seq_len: 4096 slower: 1.20x kernel: 22.70ms baseline: 18.84ms
seq_len: 8192 slower: 0.00x kernel: 90.47ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.72x kernel: 0.19ms baseline: 0.26ms
seq_len: 256 slower: 1.04x kernel: 0.24ms baseline: 0.23ms
seq_len: 512 slower: 1.04x kernel: 0.30ms baseline: 0.29ms
seq_len: 1024 slower: 1.00x kernel: 0.70ms baseline: 0.70ms
seq_len: 2048 slower: 0.71x kernel: 1.83ms baseline: 2.59ms
seq_len: 4096 slower: 0.67x kernel: 6.23ms baseline: 9.36ms
seq_len: 8192 slower: 0.65x kernel: 23.78ms baseline: 36.45ms
python3 benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.96x kernel: 0.55ms baseline: 0.57ms
seq_len: 256 slower: 1.76x kernel: 0.89ms baseline: 0.50ms
seq_len: 512 slower: 2.18x kernel: 2.09ms baseline: 0.96ms
seq_len: 1024 slower: 1.83x kernel: 5.16ms baseline: 2.82ms
seq_len: 2048 slower: 1.74x kernel: 17.56ms baseline: 10.12ms
seq_len: 4096 slower: 1.71x kernel: 64.56ms baseline: 37.74ms
seq_len: 8192 slower: 0.00x kernel: 250.87ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.92x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.60ms baseline: 0.58ms
seq_len: 512 slower: 1.54x kernel: 0.89ms baseline: 0.58ms
seq_len: 1024 slower: 1.34x kernel: 2.03ms baseline: 1.52ms
seq_len: 2048 slower: 1.20x kernel: 6.06ms baseline: 5.04ms
seq_len: 4096 slower: 1.25x kernel: 23.19ms baseline: 18.58ms
seq_len: 8192 slower: 1.22x kernel: 90.73ms baseline: 74.51ms
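For anyone reading the numbers above: the "slower" column is just kernel time divided by baseline time, so values below 1.00x mean the custom kernel wins. A minimal CPU-only sketch of such a timing harness follows (`time_ms` and `report` are hypothetical helper names, not from benchmark.py; the real script times GPU kernels, where you would use torch.cuda.Event and synchronization instead of wall-clock time):

```python
import time

def time_ms(fn, *args, warmup=3, iters=10):
    """Average wall-clock milliseconds per call, after a few warmup runs.
    (For real GPU kernels, wall-clock time is misleading: you would record
    torch.cuda.Event pairs and synchronize around the timed region.)"""
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) * 1000 / iters

def report(kernel_ms, baseline_ms):
    # "slower: 1.20x" means the kernel took 1.2x the baseline's time
    return (f"slower: {kernel_ms / baseline_ms:.2f}x "
            f"kernel: {kernel_ms:.2f}ms baseline: {baseline_ms:.2f}ms")
```

With this convention, a row like "slower: 0.65x" reads as the kernel running in 65% of the baseline's time.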
Is the speedup only seen on A100s?
I am going to train a small model on Wikitext-103 on an A100 cluster next and report the results.
Thank you,
Enrico
@conceptofmind Hey Enrico! :wave:
Thanks, but most of the credit goes to @ahennequ; half of the code in the repository comes from https://github.com/ahennequ/pytorch-custom-mma
Yup, it is expected to be slower on older graphics cards. Even the Meta AI folks working on regular flash attention saw a backwards pass on the order of 1.5-2x slower.
However, on A100, where shared memory is plentiful, it should be a lot faster. Arthur has already demonstrated the forwards pass to be much faster in his repository. I haven't validated the backwards pass to be faster yet, but that would be good to check if you have access to one!
@conceptofmind also, try the autoregressive mode with python benchmark.py --causal. You should see a 2x speedup yet again, since the CUDA code is able to avoid any computation on the upper triangular mask. Key padding helps as well, depending on how much is masked.
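A toy illustration of why the causal speedup is roughly 2x (this is plain Python, not the repository's CUDA kernel; in the real kernel the skipping happens per tile rather than per score): for query i, only keys j <= i are ever needed, so the upper triangle of the score matrix is simply never computed.

```python
import math

def attention(q, k, v, causal=False):
    """Toy single-head attention over lists of vectors.

    With causal=True, the loop over keys stops at j == i, so roughly half
    of the n*n score matrix (the masked upper triangle) is skipped
    entirely instead of being computed and then masked out.
    """
    n, ops, out = len(q), 0, []
    for i in range(n):
        ks = range(i + 1) if causal else range(n)  # skip masked keys entirely
        scores = []
        for j in ks:
            ops += 1
            scores.append(sum(a * b for a, b in zip(q[i], k[j])))
        m = max(scores)
        w = [math.exp(s - m) for s in scores]      # numerically stable softmax
        z = sum(w)
        out.append([sum(wj / z * v[j][d] for wj, j in zip(w, ks))
                    for d in range(len(v[0]))])
    return out, ops
```

For sequence length n, the causal path computes n(n+1)/2 scores versus n*n for the full path, which is where the roughly 2x saving at long sequence lengths comes from; key-padding masks allow the same kind of skipping for padded columns.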
@lucidrains Thank you for the additional info and for confirming that it is expected to be slower. And thanks to @ahennequ as well!
I will run all of the benchmarks on an A100 and work on further validating the speed performance for the backward pass.
For consistency, here are the results on an RTX 3090 for python benchmark.py --causal:
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.84x kernel: 0.77ms baseline: 0.91ms
seq_len: 256 slower: 1.29x kernel: 1.13ms baseline: 0.87ms
seq_len: 512 slower: 1.03x kernel: 1.86ms baseline: 1.80ms
seq_len: 1024 slower: 0.80x kernel: 4.38ms baseline: 5.44ms
seq_len: 2048 slower: 0.62x kernel: 12.61ms baseline: 20.44ms
seq_len: 4096 slower: 0.59x kernel: 45.52ms baseline: 77.76ms
seq_len: 8192 slower: 0.00x kernel: 173.98ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.84x kernel: 0.75ms baseline: 0.90ms
seq_len: 256 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 512 slower: 1.00x kernel: 1.02ms baseline: 1.02ms
seq_len: 1024 slower: 0.59x kernel: 1.76ms baseline: 2.98ms
seq_len: 2048 slower: 0.44x kernel: 4.58ms baseline: 10.44ms
seq_len: 4096 slower: 0.37x kernel: 14.90ms baseline: 39.77ms
seq_len: 8192 slower: 0.37x kernel: 57.78ms baseline: 156.87ms
I will post the benchmarks on an A100 as soon as I can.
Thank you again,
Enrico
Hi @lucidrains,
I ran the benchmarks on an instance with 4 A100 (40GB) GPUs.
Here are the results:
python3 benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.14x kernel: 0.31ms baseline: 0.27ms
seq_len: 256 slower: 1.33x kernel: 0.44ms baseline: 0.33ms
seq_len: 512 slower: 1.16x kernel: 0.88ms baseline: 0.76ms
seq_len: 1024 slower: 1.10x kernel: 2.32ms baseline: 2.11ms
seq_len: 2048 slower: 1.20x kernel: 6.09ms baseline: 5.06ms
seq_len: 4096 slower: 0.86x kernel: 16.70ms baseline: 19.35ms
seq_len: 8192 slower: 0.60x kernel: 62.88ms baseline: 105.39ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.86x kernel: 0.23ms baseline: 0.27ms
seq_len: 256 slower: 1.23x kernel: 0.32ms baseline: 0.26ms
seq_len: 512 slower: 1.24x kernel: 0.33ms baseline: 0.27ms
seq_len: 1024 slower: 1.26x kernel: 0.74ms baseline: 0.59ms
seq_len: 2048 slower: 1.21x kernel: 1.94ms baseline: 1.60ms
seq_len: 4096 slower: 1.35x kernel: 8.58ms baseline: 6.37ms
seq_len: 8192 slower: 0.90x kernel: 23.66ms baseline: 26.41ms
python3 benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.01x kernel: 0.86ms baseline: 0.85ms
seq_len: 256 slower: 1.60x kernel: 1.32ms baseline: 0.83ms
seq_len: 512 slower: 1.48x kernel: 2.31ms baseline: 1.56ms
seq_len: 1024 slower: 1.39x kernel: 4.42ms baseline: 3.18ms
seq_len: 2048 slower: 1.31x kernel: 14.29ms baseline: 10.90ms
seq_len: 4096 slower: 1.24x kernel: 51.57ms baseline: 41.75ms
seq_len: 8192 slower: 1.24x kernel: 201.62ms baseline: 163.18ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.89x kernel: 0.79ms baseline: 0.88ms
seq_len: 256 slower: 0.98x kernel: 0.88ms baseline: 0.90ms
seq_len: 512 slower: 1.28x kernel: 1.15ms baseline: 0.90ms
seq_len: 1024 slower: 1.97x kernel: 1.96ms baseline: 1.00ms
seq_len: 2048 slower: 1.68x kernel: 5.10ms baseline: 3.04ms
seq_len: 4096 slower: 1.63x kernel: 17.90ms baseline: 11.00ms
seq_len: 8192 slower: 1.62x kernel: 66.28ms baseline: 41.02ms
python3 benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.83x kernel: 1.21ms baseline: 1.46ms
seq_len: 256 slower: 1.08x kernel: 1.47ms baseline: 1.36ms
seq_len: 512 slower: 0.82x kernel: 2.45ms baseline: 2.97ms
seq_len: 1024 slower: 0.67x kernel: 3.80ms baseline: 5.71ms
seq_len: 2048 slower: 0.54x kernel: 10.23ms baseline: 18.79ms
seq_len: 4096 slower: 0.49x kernel: 36.01ms baseline: 73.93ms
seq_len: 8192 slower: 0.47x kernel: 135.94ms baseline: 290.59ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.78x kernel: 1.06ms baseline: 1.36ms
seq_len: 256 slower: 0.83x kernel: 1.14ms baseline: 1.37ms
seq_len: 512 slower: 1.00x kernel: 1.37ms baseline: 1.37ms
seq_len: 1024 slower: 0.96x kernel: 2.04ms baseline: 2.13ms
seq_len: 2048 slower: 0.62x kernel: 4.32ms baseline: 6.97ms
seq_len: 4096 slower: 0.51x kernel: 13.36ms baseline: 26.21ms
seq_len: 8192 slower: 0.46x kernel: 47.16ms baseline: 103.51ms
No OOM on the forward, backward, or causal passes now, though performance was still slower on the benchmarks overall.
I will try an A100 (80 GB) instance as well.
System info for cloud instance:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   32C    P0    41W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   32C    P0    47W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-PCI...  On   | 00000000:45:00.0 Off |                    0 |
| N/A   33C    P0    45W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:46:00.0 Off |                    0 |
| N/A   33C    P0    25W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
Thank you,
Enrico
@lucidrains
I have also run the benchmarks on 2 A100 (80GB) GPUs:
python3 benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.12x kernel: 0.28ms baseline: 0.25ms
seq_len: 256 slower: 1.28x kernel: 0.33ms baseline: 0.26ms
seq_len: 512 slower: 1.16x kernel: 0.59ms baseline: 0.50ms
seq_len: 1024 slower: 1.11x kernel: 1.46ms baseline: 1.31ms
seq_len: 2048 slower: 1.03x kernel: 4.64ms baseline: 4.49ms
seq_len: 4096 slower: 0.94x kernel: 16.66ms baseline: 17.79ms
seq_len: 8192 slower: 0.90x kernel: 62.53ms baseline: 69.56ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.88x kernel: 0.22ms baseline: 0.26ms
seq_len: 256 slower: 1.00x kernel: 0.26ms baseline: 0.26ms
seq_len: 512 slower: 1.29x kernel: 0.33ms baseline: 0.26ms
seq_len: 1024 slower: 1.42x kernel: 0.73ms baseline: 0.52ms
seq_len: 2048 slower: 1.31x kernel: 1.93ms baseline: 1.48ms
seq_len: 4096 slower: 1.13x kernel: 6.49ms baseline: 5.73ms
seq_len: 8192 slower: 0.98x kernel: 23.32ms baseline: 23.74ms
python3 benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.00x kernel: 1.07ms baseline: 1.06ms
seq_len: 256 slower: 1.09x kernel: 1.29ms baseline: 1.18ms
seq_len: 512 slower: 1.63x kernel: 1.94ms baseline: 1.19ms
seq_len: 1024 slower: 1.38x kernel: 4.15ms baseline: 3.01ms
seq_len: 2048 slower: 1.43x kernel: 14.21ms baseline: 9.94ms
seq_len: 4096 slower: 1.34x kernel: 51.12ms baseline: 38.18ms
seq_len: 8192 slower: 1.34x kernel: 200.37ms baseline: 149.35ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.90x kernel: 1.02ms baseline: 1.14ms
seq_len: 256 slower: 0.98x kernel: 1.02ms baseline: 1.04ms
seq_len: 512 slower: 1.25x kernel: 1.30ms baseline: 1.04ms
seq_len: 1024 slower: 2.06x kernel: 2.13ms baseline: 1.03ms
seq_len: 2048 slower: 1.95x kernel: 5.09ms baseline: 2.60ms
seq_len: 4096 slower: 1.90x kernel: 17.74ms baseline: 9.33ms
seq_len: 8192 slower: 1.90x kernel: 65.94ms baseline: 34.62ms
python3 benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.79x kernel: 1.40ms baseline: 1.77ms
seq_len: 256 slower: 0.82x kernel: 1.56ms baseline: 1.90ms
seq_len: 512 slower: 1.02x kernel: 1.97ms baseline: 1.92ms
seq_len: 1024 slower: 0.73x kernel: 3.70ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.10ms baseline: 17.41ms
seq_len: 4096 slower: 0.52x kernel: 35.69ms baseline: 68.34ms
seq_len: 8192 slower: 0.51x kernel: 134.92ms baseline: 265.93ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.86x kernel: 1.46ms baseline: 1.69ms
seq_len: 256 slower: 0.51x kernel: 0.87ms baseline: 1.70ms
seq_len: 512 slower: 0.95x kernel: 1.63ms baseline: 1.71ms
seq_len: 1024 slower: 1.08x kernel: 2.21ms baseline: 2.04ms
seq_len: 2048 slower: 0.72x kernel: 4.45ms baseline: 6.17ms
seq_len: 4096 slower: 0.57x kernel: 13.31ms baseline: 23.39ms
seq_len: 8192 slower: 0.51x kernel: 46.79ms baseline: 92.37ms
In my tests, I am still seeing slightly slower performance on forward, backward, and causal passes.
System information for 2 A100 (80 GB) GPUs:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100 80G...  Off  | 00000000:05:00.0 Off |                    0 |
| N/A   33C    P0    42W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:88:00.0 Off |                    0 |
| N/A   32C    P0    43W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Thank you again,
Enrico
@conceptofmind thank you for running this! 🙏 looks like non autoregressive still needs some work, but definitely close to the finish line (and still room to optimize)
@lucidrains Very close and still incredibly impressive! I will also test the gpt-2 model with the two A100 (80 GB) GPUs on wikitext-103 / enwiki8 and document the training results here as well.
Best,
Enrico
@conceptofmind Hey Enrico, I moved some of the peripheral Python into C++, at Arthur's recommendation.
If you have some time, do you think you could rerun the benchmarks on your 3090s or A100s and share what you see on your end? I will likely end up crafting custom kernels for the A100 if the results are still not up to par.
@lucidrains Of course. I will rerun each of the benchmarks on an RTX 3090, A100 (40 GB), and A100 (80 GB) and document the results here. I am still running the gpt-2 model on wikitext-103.
Best,
Enrico
@lucidrains Here are the results for the new RTX 3090 benchmark run:
python3 benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.97x kernel: 0.23ms baseline: 0.24ms
seq_len: 256 slower: 2.20x kernel: 0.53ms baseline: 0.24ms
seq_len: 512 slower: 1.72x kernel: 0.81ms baseline: 0.47ms
seq_len: 1024 slower: 1.66x kernel: 2.21ms baseline: 1.34ms
seq_len: 2048 slower: 1.18x kernel: 5.94ms baseline: 5.01ms
seq_len: 4096 slower: 1.19x kernel: 22.51ms baseline: 18.92ms
seq_len: 8192 slower: 0.00x kernel: 90.21ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.80x kernel: 0.20ms baseline: 0.26ms
seq_len: 256 slower: 1.03x kernel: 0.24ms baseline: 0.23ms
seq_len: 512 slower: 1.04x kernel: 0.30ms baseline: 0.29ms
seq_len: 1024 slower: 0.99x kernel: 0.70ms baseline: 0.70ms
seq_len: 2048 slower: 0.70x kernel: 1.83ms baseline: 2.63ms
seq_len: 4096 slower: 0.68x kernel: 6.38ms baseline: 9.34ms
seq_len: 8192 slower: 0.66x kernel: 23.78ms baseline: 36.23ms
python3 benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.95x kernel: 0.55ms baseline: 0.58ms
seq_len: 256 slower: 1.67x kernel: 0.98ms baseline: 0.58ms
seq_len: 512 slower: 2.21x kernel: 2.11ms baseline: 0.96ms
seq_len: 1024 slower: 1.82x kernel: 5.12ms baseline: 2.81ms
seq_len: 2048 slower: 1.73x kernel: 17.39ms baseline: 10.08ms
seq_len: 4096 slower: 1.71x kernel: 64.23ms baseline: 37.59ms
seq_len: 8192 slower: 0.00x kernel: 251.81ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.92x kernel: 0.50ms baseline: 0.54ms
seq_len: 256 slower: 1.08x kernel: 0.60ms baseline: 0.55ms
seq_len: 512 slower: 1.57x kernel: 0.90ms baseline: 0.57ms
seq_len: 1024 slower: 1.33x kernel: 2.03ms baseline: 1.53ms
seq_len: 2048 slower: 1.20x kernel: 6.06ms baseline: 5.03ms
seq_len: 4096 slower: 1.25x kernel: 23.26ms baseline: 18.61ms
seq_len: 8192 slower: 1.27x kernel: 90.30ms baseline: 71.28ms
python3 benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.86x kernel: 0.78ms baseline: 0.91ms
seq_len: 256 slower: 1.24x kernel: 1.26ms baseline: 1.01ms
seq_len: 512 slower: 1.02x kernel: 1.88ms baseline: 1.85ms
seq_len: 1024 slower: 0.73x kernel: 3.93ms baseline: 5.42ms
seq_len: 2048 slower: 0.62x kernel: 12.58ms baseline: 20.45ms
seq_len: 4096 slower: 0.58x kernel: 45.40ms baseline: 77.65ms
seq_len: 8192 slower: 0.00x kernel: 177.48ms baseline: oom
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.88x kernel: 0.82ms baseline: 0.93ms
seq_len: 256 slower: 0.97x kernel: 0.92ms baseline: 0.95ms
seq_len: 512 slower: 1.04x kernel: 1.06ms baseline: 1.02ms
seq_len: 1024 slower: 0.60x kernel: 1.80ms baseline: 3.00ms
seq_len: 2048 slower: 0.43x kernel: 5.06ms baseline: 11.84ms
seq_len: 4096 slower: 0.33x kernel: 14.87ms baseline: 44.85ms
seq_len: 8192 slower: 0.34x kernel: 57.54ms baseline: 171.53ms
I am benchmarking on the A100s now and will update you soon.
@lucidrains Here are the results for the new A100 (80 GB) benchmark run:
python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 2.75x kernel: 0.67ms baseline: 0.24ms
seq_len: 256 slower: 1.47x kernel: 0.44ms baseline: 0.30ms
seq_len: 512 slower: 1.17x kernel: 0.62ms baseline: 0.53ms
seq_len: 1024 slower: 1.02x kernel: 1.50ms baseline: 1.47ms
seq_len: 2048 slower: 1.04x kernel: 4.70ms baseline: 4.54ms
seq_len: 4096 slower: 0.94x kernel: 16.86ms baseline: 17.95ms
seq_len: 8192 slower: 0.83x kernel: 63.45ms baseline: 76.31ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.23ms baseline: 0.27ms
seq_len: 256 slower: 8.96x kernel: 3.14ms baseline: 0.35ms
seq_len: 512 slower: 1.74x kernel: 0.52ms baseline: 0.30ms
seq_len: 1024 slower: 7.83x kernel: 3.82ms baseline: 0.49ms
seq_len: 2048 slower: 1.18x kernel: 1.76ms baseline: 1.50ms
seq_len: 4096 slower: 1.00x kernel: 5.78ms baseline: 5.79ms
seq_len: 8192 slower: 0.89x kernel: 21.26ms baseline: 23.88ms
python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.90x kernel: 0.96ms baseline: 1.07ms
seq_len: 256 slower: 1.11x kernel: 1.25ms baseline: 1.13ms
seq_len: 512 slower: 1.48x kernel: 1.74ms baseline: 1.18ms
seq_len: 1024 slower: 1.33x kernel: 4.00ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.12ms baseline: 9.94ms
seq_len: 4096 slower: 1.33x kernel: 50.84ms baseline: 38.27ms
seq_len: 8192 slower: 1.33x kernel: 199.24ms baseline: 149.77ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.69x kernel: 0.71ms baseline: 1.03ms
seq_len: 256 slower: 0.82x kernel: 0.74ms baseline: 0.90ms
seq_len: 512 slower: 1.21x kernel: 1.38ms baseline: 1.14ms
seq_len: 1024 slower: 2.49x kernel: 2.78ms baseline: 1.11ms
seq_len: 2048 slower: 1.92x kernel: 5.03ms baseline: 2.62ms
seq_len: 4096 slower: 1.88x kernel: 17.66ms baseline: 9.39ms
seq_len: 8192 slower: 1.90x kernel: 65.91ms baseline: 34.73ms
python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.94x kernel: 1.47ms baseline: 1.57ms
seq_len: 256 slower: 0.80x kernel: 1.41ms baseline: 1.76ms
seq_len: 512 slower: 1.01x kernel: 1.98ms baseline: 1.96ms
seq_len: 1024 slower: 0.75x kernel: 3.78ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.11ms baseline: 17.36ms
seq_len: 4096 slower: 0.52x kernel: 35.61ms baseline: 68.20ms
seq_len: 8192 slower: 0.51x kernel: 134.62ms baseline: 266.14ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.76x kernel: 1.37ms baseline: 1.81ms
seq_len: 256 slower: 0.83x kernel: 1.45ms baseline: 1.74ms
seq_len: 512 slower: 0.86x kernel: 1.40ms baseline: 1.63ms
seq_len: 1024 slower: 1.10x kernel: 2.31ms baseline: 2.10ms
seq_len: 2048 slower: 0.74x kernel: 4.52ms baseline: 6.14ms
seq_len: 4096 slower: 0.56x kernel: 13.00ms baseline: 23.35ms
seq_len: 8192 slower: 0.50x kernel: 45.81ms baseline: 92.24ms
Causal is quite close!
System info:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100 80G...  Off  | 00000000:43:00.0 Off |                    0 |
| N/A   38C    P0    65W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
@conceptofmind ohh yea, that doesn't look great and A100 definitely needs more work
thank you for running the benchmarks!
this line in the benchmark
seq_len: 1024 slower: 7.83x kernel: 3.82ms baseline: 0.49ms
looks really strange. Were the benchmarks done on GPUs that are idle?
@lucidrains I tested on both Vast.ai and Coreweave cloud instances.
For each benchmark, I created a brand new cloud instance. I only installed miniconda for managing the env and PyTorch through conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch.
All of the GPUs I am using should be idle since I only ran the benchmarks. I will test on a different A100 in case there is an issue with one of their GPUs. Nvidia-smi came up clean when tracking GPU usage.
I can test on GCP as well.
@conceptofmind thanks Enrico!
@lucidrains Hi Phil,
I am testing on 8 different A100 (80 GB) devices. I will show the benchmarks for each device. Also, I forgot that I need extended permissions for 80 GB devices on GCP. I requested access to those. I will try Coreweave again later too.
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100 80G...  Off  | 00000000:04:00.0 Off |                    0 |
| N/A   36C    P0    55W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:05:00.0 Off |                    0 |
| N/A   40C    P0    65W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100 80G...  Off  | 00000000:43:00.0 Off |                    0 |
| N/A   37C    P0    64W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100 80G...  Off  | 00000000:44:00.0 Off |                    0 |
| N/A   39C    P0    70W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100 80G...  Off  | 00000000:87:00.0 Off |                    0 |
| N/A   37C    P0    65W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100 80G...  Off  | 00000000:88:00.0 Off |                    0 |
| N/A   39C    P0    65W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100 80G...  Off  | 00000000:C3:00.0 Off |                    0 |
| N/A   37C    P0    62W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100 80G...  Off  | 00000000:C4:00.0 Off |                    0 |
| N/A   41C    P0    67W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
CUDA DEVICE 0
CUDA_VISIBLE_DEVICES=0 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 2.14x kernel: 0.52ms baseline: 0.24ms
seq_len: 256 slower: 1.52x kernel: 0.45ms baseline: 0.29ms
seq_len: 512 slower: 1.17x kernel: 0.66ms baseline: 0.57ms
seq_len: 1024 slower: 1.06x kernel: 1.51ms baseline: 1.42ms
seq_len: 2048 slower: 1.05x kernel: 4.73ms baseline: 4.52ms
seq_len: 4096 slower: 0.94x kernel: 16.89ms baseline: 18.01ms
seq_len: 8192 slower: 0.83x kernel: 63.64ms baseline: 76.48ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.83x kernel: 0.24ms baseline: 0.30ms
seq_len: 256 slower: 0.99x kernel: 0.30ms baseline: 0.30ms
seq_len: 512 slower: 1.19x kernel: 0.35ms baseline: 0.30ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.48ms
seq_len: 2048 slower: 1.19x kernel: 1.76ms baseline: 1.48ms
seq_len: 4096 slower: 1.00x kernel: 5.79ms baseline: 5.80ms
seq_len: 8192 slower: 0.89x kernel: 21.32ms baseline: 24.03ms
CUDA_VISIBLE_DEVICES=0 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.02x kernel: 0.94ms baseline: 0.92ms
seq_len: 256 slower: 1.11x kernel: 0.95ms baseline: 0.86ms
seq_len: 512 slower: 1.57x kernel: 1.83ms baseline: 1.17ms
seq_len: 1024 slower: 1.28x kernel: 3.86ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.14ms baseline: 9.99ms
seq_len: 4096 slower: 1.32x kernel: 50.93ms baseline: 38.46ms
seq_len: 8192 slower: 1.33x kernel: 199.75ms baseline: 150.70ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.88x kernel: 0.89ms baseline: 1.02ms
seq_len: 256 slower: 0.89x kernel: 0.75ms baseline: 0.85ms
seq_len: 512 slower: 1.24x kernel: 1.10ms baseline: 0.88ms
seq_len: 1024 slower: 1.85x kernel: 2.10ms baseline: 1.13ms
seq_len: 2048 slower: 1.94x kernel: 5.09ms baseline: 2.62ms
seq_len: 4096 slower: 1.89x kernel: 17.72ms baseline: 9.40ms
seq_len: 8192 slower: 1.90x kernel: 66.00ms baseline: 34.79ms
CUDA_VISIBLE_DEVICES=0 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.79x kernel: 1.41ms baseline: 1.79ms
seq_len: 256 slower: 0.97x kernel: 1.06ms baseline: 1.09ms
seq_len: 512 slower: 1.00x kernel: 1.97ms baseline: 1.97ms
seq_len: 1024 slower: 0.71x kernel: 3.62ms baseline: 5.09ms
seq_len: 2048 slower: 0.58x kernel: 10.17ms baseline: 17.45ms
seq_len: 4096 slower: 0.52x kernel: 35.78ms baseline: 68.47ms
seq_len: 8192 slower: 0.50x kernel: 135.12ms baseline: 267.71ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.77x kernel: 1.28ms baseline: 1.67ms
seq_len: 256 slower: 0.83x kernel: 1.10ms baseline: 1.33ms
seq_len: 512 slower: 1.01x kernel: 1.33ms baseline: 1.32ms
seq_len: 1024 slower: 1.03x kernel: 2.06ms baseline: 1.99ms
seq_len: 2048 slower: 0.69x kernel: 4.27ms baseline: 6.20ms
seq_len: 4096 slower: 0.56x kernel: 13.10ms baseline: 23.47ms
seq_len: 8192 slower: 0.50x kernel: 45.88ms baseline: 92.68ms
CUDA DEVICE 1
CUDA_VISIBLE_DEVICES=1 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.11x kernel: 0.27ms baseline: 0.25ms
seq_len: 256 slower: 1.22x kernel: 0.36ms baseline: 0.29ms
seq_len: 512 slower: 1.19x kernel: 0.62ms baseline: 0.52ms
seq_len: 1024 slower: 1.09x kernel: 1.51ms baseline: 1.39ms
seq_len: 2048 slower: 1.05x kernel: 4.71ms baseline: 4.48ms
seq_len: 4096 slower: 0.20x kernel: 16.89ms baseline: 86.22ms
seq_len: 8192 slower: 1.17x kernel: 81.63ms baseline: 69.52ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.21ms baseline: 0.25ms
seq_len: 256 slower: 1.07x kernel: 0.32ms baseline: 0.30ms
seq_len: 512 slower: 1.19x kernel: 0.36ms baseline: 0.30ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.49ms
seq_len: 2048 slower: 1.18x kernel: 1.75ms baseline: 1.48ms
seq_len: 4096 slower: 1.01x kernel: 5.78ms baseline: 5.73ms
seq_len: 8192 slower: 0.89x kernel: 21.23ms baseline: 23.78ms
CUDA_VISIBLE_DEVICES=1 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.89x kernel: 0.99ms baseline: 1.11ms
seq_len: 256 slower: 1.18x kernel: 1.29ms baseline: 1.09ms
seq_len: 512 slower: 1.55x kernel: 1.81ms baseline: 1.16ms
seq_len: 1024 slower: 1.33x kernel: 4.00ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.06ms baseline: 9.94ms
seq_len: 4096 slower: 1.33x kernel: 50.85ms baseline: 38.23ms
seq_len: 8192 slower: 1.33x kernel: 199.25ms baseline: 149.76ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.79x kernel: 0.92ms baseline: 1.16ms
seq_len: 256 slower: 0.92x kernel: 0.92ms baseline: 1.00ms
seq_len: 512 slower: 1.17x kernel: 1.36ms baseline: 1.17ms
seq_len: 1024 slower: 1.85x kernel: 2.10ms baseline: 1.13ms
seq_len: 2048 slower: 1.94x kernel: 5.07ms baseline: 2.61ms
seq_len: 4096 slower: 1.89x kernel: 17.67ms baseline: 9.37ms
seq_len: 8192 slower: 1.90x kernel: 65.81ms baseline: 34.69ms
CUDA_VISIBLE_DEVICES=1 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.91x kernel: 1.48ms baseline: 1.63ms
seq_len: 256 slower: 0.60x kernel: 1.03ms baseline: 1.72ms
seq_len: 512 slower: 0.96x kernel: 1.78ms baseline: 1.86ms
seq_len: 1024 slower: 0.74x kernel: 3.75ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.14ms baseline: 17.37ms
seq_len: 4096 slower: 0.52x kernel: 35.62ms baseline: 68.21ms
seq_len: 8192 slower: 0.51x kernel: 134.61ms baseline: 266.15ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.67x kernel: 1.23ms baseline: 1.83ms
seq_len: 256 slower: 0.66x kernel: 1.19ms baseline: 1.81ms
seq_len: 512 slower: 0.88x kernel: 1.74ms baseline: 1.99ms
seq_len: 1024 slower: 1.09x kernel: 2.29ms baseline: 2.10ms
seq_len: 2048 slower: 0.74x kernel: 4.58ms baseline: 6.17ms
seq_len: 4096 slower: 0.56x kernel: 13.05ms baseline: 23.35ms
seq_len: 8192 slower: 0.49x kernel: 45.66ms baseline: 92.26ms
CUDA DEVICE 2
CUDA_VISIBLE_DEVICES=2 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.01x kernel: 0.25ms baseline: 0.24ms
seq_len: 256 slower: 1.70x kernel: 0.50ms baseline: 0.29ms
seq_len: 512 slower: 1.21x kernel: 0.63ms baseline: 0.52ms
seq_len: 1024 slower: 1.05x kernel: 1.51ms baseline: 1.43ms
seq_len: 2048 slower: 1.04x kernel: 4.71ms baseline: 4.54ms
seq_len: 4096 slower: 0.94x kernel: 16.87ms baseline: 17.86ms
seq_len: 8192 slower: 0.85x kernel: 63.45ms baseline: 75.07ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.23ms baseline: 0.26ms
seq_len: 256 slower: 0.96x kernel: 0.28ms baseline: 0.29ms
seq_len: 512 slower: 1.20x kernel: 0.35ms baseline: 0.29ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.48ms
seq_len: 2048 slower: 1.18x kernel: 1.76ms baseline: 1.49ms
seq_len: 4096 slower: 1.00x kernel: 5.78ms baseline: 5.77ms
seq_len: 8192 slower: 0.89x kernel: 21.23ms baseline: 23.84ms
CUDA_VISIBLE_DEVICES=2 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.07x kernel: 1.02ms baseline: 0.95ms
seq_len: 256 slower: 1.10x kernel: 1.29ms baseline: 1.18ms
seq_len: 512 slower: 1.64x kernel: 1.93ms baseline: 1.18ms
seq_len: 1024 slower: 1.33x kernel: 4.02ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.13ms baseline: 9.94ms
seq_len: 4096 slower: 1.33x kernel: 50.84ms baseline: 38.30ms
seq_len: 8192 slower: 1.33x kernel: 199.32ms baseline: 149.81ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.87x kernel: 0.97ms baseline: 1.11ms
seq_len: 256 slower: 0.76x kernel: 0.68ms baseline: 0.89ms
seq_len: 512 slower: 1.18x kernel: 1.33ms baseline: 1.13ms
seq_len: 1024 slower: 1.86x kernel: 2.03ms baseline: 1.09ms
seq_len: 2048 slower: 1.92x kernel: 5.03ms baseline: 2.61ms
seq_len: 4096 slower: 1.88x kernel: 17.68ms baseline: 9.39ms
seq_len: 8192 slower: 1.91x kernel: 66.38ms baseline: 34.74ms
CUDA_VISIBLE_DEVICES=2 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.74x kernel: 1.41ms baseline: 1.91ms
seq_len: 256 slower: 1.25x kernel: 1.55ms baseline: 1.23ms
seq_len: 512 slower: 0.99x kernel: 1.96ms baseline: 1.97ms
seq_len: 1024 slower: 0.73x kernel: 3.72ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.10ms baseline: 17.37ms
seq_len: 4096 slower: 0.52x kernel: 35.61ms baseline: 68.20ms
seq_len: 8192 slower: 0.51x kernel: 134.67ms baseline: 266.21ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.72x kernel: 1.29ms baseline: 1.81ms
seq_len: 256 slower: 0.78x kernel: 1.09ms baseline: 1.39ms
seq_len: 512 slower: 0.99x kernel: 1.39ms baseline: 1.41ms
seq_len: 1024 slower: 1.08x kernel: 2.24ms baseline: 2.08ms
seq_len: 2048 slower: 0.73x kernel: 4.46ms baseline: 6.15ms
seq_len: 4096 slower: 0.56x kernel: 12.99ms baseline: 23.34ms
seq_len: 8192 slower: 0.50x kernel: 45.79ms baseline: 92.24ms
CUDA DEVICE 3
CUDA_VISIBLE_DEVICES=3 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.00x kernel: 0.25ms baseline: 0.25ms
seq_len: 256 slower: 1.22x kernel: 0.36ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.65ms baseline: 0.52ms
seq_len: 1024 slower: 1.07x kernel: 1.51ms baseline: 1.41ms
seq_len: 2048 slower: 1.05x kernel: 4.72ms baseline: 4.52ms
seq_len: 4096 slower: 0.94x kernel: 16.90ms baseline: 18.00ms
seq_len: 8192 slower: 1.06x kernel: 74.91ms baseline: 70.35ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.86x kernel: 0.22ms baseline: 0.25ms
seq_len: 256 slower: 0.95x kernel: 0.28ms baseline: 0.30ms
seq_len: 512 slower: 1.30x kernel: 0.36ms baseline: 0.28ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.49ms
seq_len: 2048 slower: 1.17x kernel: 1.73ms baseline: 1.48ms
seq_len: 4096 slower: 1.00x kernel: 5.77ms baseline: 5.78ms
seq_len: 8192 slower: 0.89x kernel: 21.26ms baseline: 23.97ms
CUDA_VISIBLE_DEVICES=3 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.88x kernel: 0.94ms baseline: 1.07ms
seq_len: 256 slower: 1.09x kernel: 1.23ms baseline: 1.13ms
seq_len: 512 slower: 1.57x kernel: 1.84ms baseline: 1.18ms
seq_len: 1024 slower: 1.35x kernel: 4.05ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.12ms baseline: 9.96ms
seq_len: 4096 slower: 1.32x kernel: 50.89ms baseline: 38.44ms
seq_len: 8192 slower: 1.32x kernel: 199.46ms baseline: 150.65ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.84x kernel: 1.00ms baseline: 1.20ms
seq_len: 256 slower: 0.95x kernel: 1.10ms baseline: 1.16ms
seq_len: 512 slower: 1.21x kernel: 1.07ms baseline: 0.88ms
seq_len: 1024 slower: 1.91x kernel: 2.10ms baseline: 1.10ms
seq_len: 2048 slower: 1.92x kernel: 5.05ms baseline: 2.62ms
seq_len: 4096 slower: 1.87x kernel: 17.66ms baseline: 9.42ms
seq_len: 8192 slower: 1.91x kernel: 66.37ms baseline: 34.81ms
CUDA_VISIBLE_DEVICES=3 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.74x kernel: 1.36ms baseline: 1.83ms
seq_len: 256 slower: 0.83x kernel: 1.45ms baseline: 1.75ms
seq_len: 512 slower: 1.14x kernel: 2.08ms baseline: 1.83ms
seq_len: 1024 slower: 0.64x kernel: 3.27ms baseline: 5.10ms
seq_len: 2048 slower: 0.58x kernel: 10.13ms baseline: 17.45ms
seq_len: 4096 slower: 0.52x kernel: 35.65ms baseline: 68.44ms
seq_len: 8192 slower: 0.50x kernel: 134.73ms baseline: 267.22ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.69x kernel: 1.12ms baseline: 1.63ms
seq_len: 256 slower: 0.78x kernel: 1.11ms baseline: 1.42ms
seq_len: 512 slower: 0.82x kernel: 1.60ms baseline: 1.94ms
seq_len: 1024 slower: 1.07x kernel: 2.23ms baseline: 2.09ms
seq_len: 2048 slower: 0.72x kernel: 4.43ms baseline: 6.20ms
seq_len: 4096 slower: 0.55x kernel: 13.01ms baseline: 23.49ms
seq_len: 8192 slower: 0.49x kernel: 45.83ms baseline: 92.63ms
CUDA DEVICE 4
CUDA_VISIBLE_DEVICES=4 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.47x kernel: 0.36ms baseline: 0.25ms
seq_len: 256 slower: 1.21x kernel: 0.36ms baseline: 0.30ms
seq_len: 512 slower: 1.25x kernel: 0.65ms baseline: 0.52ms
seq_len: 1024 slower: 1.09x kernel: 1.51ms baseline: 1.38ms
seq_len: 2048 slower: 1.05x kernel: 4.73ms baseline: 4.51ms
seq_len: 4096 slower: 0.94x kernel: 16.85ms baseline: 17.88ms
seq_len: 8192 slower: 0.91x kernel: 63.46ms baseline: 70.10ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.22ms baseline: 0.25ms
seq_len: 256 slower: 0.94x kernel: 0.29ms baseline: 0.30ms
seq_len: 512 slower: 1.19x kernel: 0.36ms baseline: 0.30ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.49ms
seq_len: 2048 slower: 1.18x kernel: 1.74ms baseline: 1.48ms
seq_len: 4096 slower: 1.00x kernel: 5.77ms baseline: 5.77ms
seq_len: 8192 slower: 0.89x kernel: 21.28ms baseline: 23.94ms
CUDA_VISIBLE_DEVICES=4 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.97x kernel: 1.08ms baseline: 1.12ms
seq_len: 256 slower: 1.12x kernel: 0.99ms baseline: 0.88ms
seq_len: 512 slower: 1.48x kernel: 1.67ms baseline: 1.13ms
seq_len: 1024 slower: 1.30x kernel: 3.92ms baseline: 3.02ms
seq_len: 2048 slower: 1.42x kernel: 14.15ms baseline: 9.96ms
seq_len: 4096 slower: 1.33x kernel: 50.86ms baseline: 38.39ms
seq_len: 8192 slower: 1.33x kernel: 199.24ms baseline: 150.33ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.64x kernel: 0.66ms baseline: 1.03ms
seq_len: 256 slower: 0.60x kernel: 0.68ms baseline: 1.12ms
seq_len: 512 slower: 1.27x kernel: 1.12ms baseline: 0.88ms
seq_len: 1024 slower: 1.95x kernel: 2.12ms baseline: 1.09ms
seq_len: 2048 slower: 1.92x kernel: 5.04ms baseline: 2.62ms
seq_len: 4096 slower: 1.88x kernel: 17.65ms baseline: 9.41ms
seq_len: 8192 slower: 1.91x kernel: 66.44ms baseline: 34.80ms
CUDA_VISIBLE_DEVICES=4 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.95x kernel: 1.51ms baseline: 1.59ms
seq_len: 256 slower: 0.57x kernel: 1.00ms baseline: 1.75ms
seq_len: 512 slower: 0.94x kernel: 1.86ms baseline: 1.99ms
seq_len: 1024 slower: 0.73x kernel: 3.72ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.12ms baseline: 17.40ms
seq_len: 4096 slower: 0.52x kernel: 35.63ms baseline: 68.31ms
seq_len: 8192 slower: 0.50x kernel: 134.81ms baseline: 267.01ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.82x kernel: 1.33ms baseline: 1.63ms
seq_len: 256 slower: 0.71x kernel: 1.16ms baseline: 1.65ms
seq_len: 512 slower: 0.87x kernel: 1.63ms baseline: 1.88ms
seq_len: 1024 slower: 1.09x kernel: 2.28ms baseline: 2.09ms
seq_len: 2048 slower: 0.74x kernel: 4.56ms baseline: 6.17ms
seq_len: 4096 slower: 0.56x kernel: 12.99ms baseline: 23.38ms
seq_len: 8192 slower: 0.49x kernel: 45.83ms baseline: 92.60ms
CUDA DEVICE 5
CUDA_VISIBLE_DEVICES=5 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.96x kernel: 0.24ms baseline: 0.25ms
seq_len: 256 slower: 1.17x kernel: 0.34ms baseline: 0.29ms
seq_len: 512 slower: 1.22x kernel: 0.64ms baseline: 0.52ms
seq_len: 1024 slower: 1.09x kernel: 1.51ms baseline: 1.39ms
seq_len: 2048 slower: 1.05x kernel: 4.71ms baseline: 4.50ms
seq_len: 4096 slower: 0.20x kernel: 16.89ms baseline: 86.13ms
seq_len: 8192 slower: 1.06x kernel: 74.11ms baseline: 69.78ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.22ms baseline: 0.25ms
seq_len: 256 slower: 1.16x kernel: 0.35ms baseline: 0.30ms
seq_len: 512 slower: 1.18x kernel: 0.35ms baseline: 0.30ms
seq_len: 1024 slower: 1.25x kernel: 0.61ms baseline: 0.49ms
seq_len: 2048 slower: 1.18x kernel: 1.74ms baseline: 1.48ms
seq_len: 4096 slower: 1.00x kernel: 5.78ms baseline: 5.78ms
seq_len: 8192 slower: 0.89x kernel: 21.24ms baseline: 23.87ms
CUDA_VISIBLE_DEVICES=5 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.90x kernel: 1.04ms baseline: 1.15ms
seq_len: 256 slower: 0.96x kernel: 1.07ms baseline: 1.11ms
seq_len: 512 slower: 1.49x kernel: 1.80ms baseline: 1.21ms
seq_len: 1024 slower: 1.31x kernel: 3.95ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.15ms baseline: 9.96ms
seq_len: 4096 slower: 1.33x kernel: 50.86ms baseline: 38.35ms
seq_len: 8192 slower: 1.33x kernel: 199.28ms baseline: 150.14ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.97ms baseline: 1.14ms
seq_len: 256 slower: 0.65x kernel: 0.74ms baseline: 1.15ms
seq_len: 512 slower: 1.05x kernel: 1.23ms baseline: 1.17ms
seq_len: 1024 slower: 2.03x kernel: 2.10ms baseline: 1.03ms
seq_len: 2048 slower: 1.94x kernel: 5.06ms baseline: 2.61ms
seq_len: 4096 slower: 1.88x kernel: 17.67ms baseline: 9.39ms
seq_len: 8192 slower: 1.89x kernel: 65.58ms baseline: 34.75ms
CUDA_VISIBLE_DEVICES=5 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.79x kernel: 1.41ms baseline: 1.79ms
seq_len: 256 slower: 0.89x kernel: 1.37ms baseline: 1.54ms
seq_len: 512 slower: 0.99x kernel: 1.82ms baseline: 1.84ms
seq_len: 1024 slower: 0.70x kernel: 3.58ms baseline: 5.08ms
seq_len: 2048 slower: 0.58x kernel: 10.10ms baseline: 17.42ms
seq_len: 4096 slower: 0.52x kernel: 35.62ms baseline: 68.30ms
seq_len: 8192 slower: 0.51x kernel: 134.71ms baseline: 266.61ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.64x kernel: 1.06ms baseline: 1.65ms
seq_len: 256 slower: 0.75x kernel: 1.31ms baseline: 1.75ms
seq_len: 512 slower: 0.94x kernel: 1.65ms baseline: 1.75ms
seq_len: 1024 slower: 1.04x kernel: 2.16ms baseline: 2.08ms
seq_len: 2048 slower: 0.72x kernel: 4.43ms baseline: 6.14ms
seq_len: 4096 slower: 0.56x kernel: 13.03ms baseline: 23.37ms
seq_len: 8192 slower: 0.49x kernel: 45.64ms baseline: 92.45ms
CUDA DEVICE 6
CUDA_VISIBLE_DEVICES=6 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.01x kernel: 0.24ms baseline: 0.24ms
seq_len: 256 slower: 1.69x kernel: 0.49ms baseline: 0.29ms
seq_len: 512 slower: 1.20x kernel: 0.63ms baseline: 0.52ms
seq_len: 1024 slower: 1.08x kernel: 1.50ms baseline: 1.39ms
seq_len: 2048 slower: 1.05x kernel: 4.72ms baseline: 4.51ms
seq_len: 4096 slower: 0.94x kernel: 16.88ms baseline: 17.87ms
seq_len: 8192 slower: 0.85x kernel: 63.48ms baseline: 74.57ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.81x kernel: 0.23ms baseline: 0.28ms
seq_len: 256 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 512 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 1024 slower: 1.25x kernel: 0.61ms baseline: 0.49ms
seq_len: 2048 slower: 1.18x kernel: 1.76ms baseline: 1.49ms
seq_len: 4096 slower: 1.00x kernel: 5.77ms baseline: 5.78ms
seq_len: 8192 slower: 0.89x kernel: 21.26ms baseline: 23.95ms
CUDA_VISIBLE_DEVICES=6 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 1.00x kernel: 1.05ms baseline: 1.04ms
seq_len: 256 slower: 0.98x kernel: 1.10ms baseline: 1.12ms
seq_len: 512 slower: 1.57x kernel: 1.83ms baseline: 1.17ms
seq_len: 1024 slower: 1.31x kernel: 3.95ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.14ms baseline: 9.95ms
seq_len: 4096 slower: 1.33x kernel: 50.85ms baseline: 38.31ms
seq_len: 8192 slower: 1.33x kernel: 199.27ms baseline: 149.85ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.87x kernel: 0.90ms baseline: 1.03ms
seq_len: 256 slower: 0.80x kernel: 0.92ms baseline: 1.15ms
seq_len: 512 slower: 1.44x kernel: 1.07ms baseline: 0.74ms
seq_len: 1024 slower: 2.00x kernel: 2.14ms baseline: 1.07ms
seq_len: 2048 slower: 1.92x kernel: 5.02ms baseline: 2.61ms
seq_len: 4096 slower: 1.88x kernel: 17.66ms baseline: 9.39ms
seq_len: 8192 slower: 1.90x kernel: 65.91ms baseline: 34.74ms
CUDA_VISIBLE_DEVICES=6 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.79x kernel: 1.36ms baseline: 1.73ms
seq_len: 256 slower: 1.04x kernel: 1.04ms baseline: 1.00ms
seq_len: 512 slower: 0.94x kernel: 1.76ms baseline: 1.87ms
seq_len: 1024 slower: 0.68x kernel: 3.44ms baseline: 5.07ms
seq_len: 2048 slower: 0.58x kernel: 10.11ms baseline: 17.40ms
seq_len: 4096 slower: 0.52x kernel: 35.61ms baseline: 68.31ms
seq_len: 8192 slower: 0.51x kernel: 134.66ms baseline: 266.42ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.78x kernel: 1.27ms baseline: 1.64ms
seq_len: 256 slower: 0.64x kernel: 0.91ms baseline: 1.43ms
seq_len: 512 slower: 0.96x kernel: 1.35ms baseline: 1.41ms
seq_len: 1024 slower: 1.11x kernel: 2.31ms baseline: 2.08ms
seq_len: 2048 slower: 0.75x kernel: 4.62ms baseline: 6.17ms
seq_len: 4096 slower: 0.56x kernel: 12.95ms baseline: 23.27ms
seq_len: 8192 slower: 0.50x kernel: 45.78ms baseline: 92.20ms
CUDA DEVICE 7
CUDA_VISIBLE_DEVICES=7 python benchmark.py --only-forwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 2.33x kernel: 0.57ms baseline: 0.25ms
seq_len: 256 slower: 1.26x kernel: 0.37ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.65ms baseline: 0.52ms
seq_len: 1024 slower: 1.06x kernel: 1.50ms baseline: 1.42ms
seq_len: 2048 slower: 1.05x kernel: 4.72ms baseline: 4.51ms
seq_len: 4096 slower: 0.95x kernel: 16.86ms baseline: 17.83ms
seq_len: 8192 slower: 0.78x kernel: 63.40ms baseline: 81.01ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.23ms baseline: 0.27ms
seq_len: 256 slower: 0.94x kernel: 0.28ms baseline: 0.30ms
seq_len: 512 slower: 1.45x kernel: 0.43ms baseline: 0.30ms
seq_len: 1024 slower: 1.26x kernel: 0.61ms baseline: 0.48ms
seq_len: 2048 slower: 1.18x kernel: 1.75ms baseline: 1.49ms
seq_len: 4096 slower: 1.00x kernel: 5.77ms baseline: 5.77ms
seq_len: 8192 slower: 0.89x kernel: 21.28ms baseline: 23.84ms
CUDA_VISIBLE_DEVICES=7 python benchmark.py --only-backwards
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.91x kernel: 0.98ms baseline: 1.09ms
seq_len: 256 slower: 1.06x kernel: 1.24ms baseline: 1.17ms
seq_len: 512 slower: 1.59x kernel: 1.92ms baseline: 1.20ms
seq_len: 1024 slower: 1.31x kernel: 3.94ms baseline: 3.01ms
seq_len: 2048 slower: 1.42x kernel: 14.15ms baseline: 9.95ms
seq_len: 4096 slower: 1.33x kernel: 50.94ms baseline: 38.29ms
seq_len: 8192 slower: 1.33x kernel: 199.43ms baseline: 150.01ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.83x kernel: 0.84ms baseline: 1.01ms
seq_len: 256 slower: 0.89x kernel: 0.75ms baseline: 0.85ms
seq_len: 512 slower: 1.33x kernel: 1.13ms baseline: 0.85ms
seq_len: 1024 slower: 1.95x kernel: 2.05ms baseline: 1.05ms
seq_len: 2048 slower: 1.94x kernel: 5.08ms baseline: 2.61ms
seq_len: 4096 slower: 1.89x kernel: 17.72ms baseline: 9.39ms
seq_len: 8192 slower: 1.92x kernel: 66.60ms baseline: 34.74ms
CUDA_VISIBLE_DEVICES=7 python benchmark.py --causal
float32 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.72x kernel: 1.30ms baseline: 1.80ms
seq_len: 256 slower: 0.57x kernel: 1.01ms baseline: 1.76ms
seq_len: 512 slower: 1.02x kernel: 1.89ms baseline: 1.84ms
seq_len: 1024 slower: 0.74x kernel: 3.75ms baseline: 5.08ms
seq_len: 2048 slower: 0.58x kernel: 10.13ms baseline: 17.41ms
seq_len: 4096 slower: 0.52x kernel: 35.70ms baseline: 68.27ms
seq_len: 8192 slower: 0.51x kernel: 135.06ms baseline: 266.42ms
float16 batch: 4 heads: 8 dim 64
seq_len: 128 slower: 0.85x kernel: 0.80ms baseline: 0.94ms
seq_len: 256 slower: 0.82x kernel: 0.95ms baseline: 1.16ms
seq_len: 512 slower: 0.98x kernel: 1.36ms baseline: 1.39ms
seq_len: 1024 slower: 1.07x kernel: 2.23ms baseline: 2.08ms
seq_len: 2048 slower: 0.73x kernel: 4.49ms baseline: 6.15ms
seq_len: 4096 slower: 0.56x kernel: 13.00ms baseline: 23.34ms
seq_len: 8192 slower: 0.50x kernel: 45.80ms baseline: 92.32ms
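For anyone who wants to repeat this sweep, the per-device runs above can be scripted instead of typed by hand. A minimal sketch, assuming the repo's `benchmark.py` is in the working directory; the commands are only constructed here, not executed:

```python
# Hypothetical sweep helper: build one benchmark command per device/mode pair,
# mirroring the manual CUDA_VISIBLE_DEVICES invocations above.
modes = ["--only-forwards", "--only-backwards", "--causal"]

cmds = [
    f"CUDA_VISIBLE_DEVICES={dev} python benchmark.py {mode}"
    for dev in range(8)  # one entry per GPU
    for mode in modes
]

for cmd in cmds:
    print(cmd)  # or hand each string to subprocess.run(cmd, shell=True) to actually run it
```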
@lucidrains The benchmarks for 8 different A100 (80 GB) devices are listed above. I made sure to try a different host, and I verified that each GPU was idle with no memory consumed by other processes.
Do these results look more normal to you? Hopefully the larger sample of tests helps!
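Since each result line follows the same fixed `seq_len: … slower: …x kernel: …ms baseline: …ms` shape, comparing runs across devices is easy to do mechanically. A sketch of a hypothetical parser (`parse_benchmark` is not part of the repo) that also handles the `baseline: oom` case:

```python
import re

# Each record in the benchmark output has the fixed shape:
#   seq_len: <n> slower: <x>x kernel: <t>ms baseline: <t>ms (or "oom")
PATTERN = re.compile(
    r"seq_len:\s*(\d+)\s+slower:\s*([\d.]+)x\s+"
    r"kernel:\s*([\d.]+)ms\s+baseline:\s*(?:([\d.]+)ms|oom)"
)

def parse_benchmark(text):
    """Return a list of (seq_len, slowdown, kernel_ms, baseline_ms) tuples.

    baseline_ms is None when the baseline ran out of memory ("oom")."""
    rows = []
    for m in PATTERN.finditer(text):
        seq_len, slower, kernel, baseline = m.groups()
        rows.append((int(seq_len), float(slower), float(kernel),
                     float(baseline) if baseline is not None else None))
    return rows
```

This makes it straightforward to, say, average the slowdown factor per sequence length over all devices.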
@conceptofmind very helpful, thank you!