
AdamW(fused=True) slower than unfused AdamW

Open • ad8e opened this issue 11 months ago • 21 comments

🐛 Describe the bug

512M parameters; mostly vanilla LM transformer. FlashAttention 2.4.2, PyTorch 2.2.0. Uses both FA and FlashRotary. Dtype: bf16. NVIDIA A40, single GPU.

Unfused: 85 TFLOPS
Fused: 68 TFLOPS
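Because both runs do the same work per step, TFLOPS and step time are inversely proportional, so the two figures translate directly into a per-step slowdown. A small sketch of that arithmetic; `training_tflops` uses the common ~6 × params × tokens approximation for a training step, which is an assumption here, since the report does not say how its TFLOPS were counted:

```python
def training_tflops(n_params: float, tokens_per_step: float, step_seconds: float) -> float:
    """Rule-of-thumb training throughput: ~6 FLOPs per parameter per token
    (forward + backward), ignoring attention-score FLOPs."""
    return 6.0 * n_params * tokens_per_step / step_seconds / 1e12


def step_time_ratio(baseline_tflops: float, candidate_tflops: float) -> float:
    """At identical work per step, step time scales as 1 / TFLOPS."""
    return baseline_tflops / candidate_tflops


ratio = step_time_ratio(85.0, 68.0)  # unfused vs fused, from the report
print(f"fused steps take {ratio:.2f}x as long ({(ratio - 1) * 100:.0f}% slower)")
```

With the reported numbers, each fused-optimizer step takes 1.25x as long as an unfused one.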

Versions

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.19.17-coreweave-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 525.147.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          96
On-line CPU(s) list:             0-95
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7413 24-Core Processor
CPU family:                      25
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
Stepping:                        1
Frequency boost:                 enabled
CPU max MHz:                     3630.8101
CPU min MHz:                     1500.0000
BogoMIPS:                        5299.52
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                  AMD-V
L1d cache:                       1.5 MiB (48 instances)
L1i cache:                       1.5 MiB (48 instances)
L2 cache:                        24 MiB (48 instances)
L3 cache:                        256 MiB (8 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchvision==0.17.0
[pip3] triton==2.2.0
[conda] Could not collect

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

ad8e, Mar 13 '24 22:03

Hmm, this is concerning. Do you mind sharing a description of the sizes of your parameters?

Also, were you using AMP or needing a GradScaler? I'd guess not, since you're using bf16 and not fp16, but I wanted to rule out possible culprits.
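The bf16-vs-fp16 reasoning can be checked directly: GradScaler mainly exists because fp16's narrow exponent range lets small gradients underflow, whereas bf16 keeps fp32's 8-bit exponent. A small check with `torch.finfo`:

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # tiny: smallest positive normal value; max: largest finite value
    print(f"{str(dtype):15} tiny={info.tiny:.3e} max={info.max:.3e}")

# A gradient-sized value below fp16's smallest subnormal (~6e-8) flushes to zero,
# while bf16, which shares fp32's exponent range, keeps it (with reduced precision).
small_grad = torch.tensor(1e-8)
as_fp16 = small_grad.to(torch.float16).item()
as_bf16 = small_grad.to(torch.bfloat16).item()
print(as_fp16, as_bf16)
```

This is why a pure-bf16 setup usually runs without loss scaling.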

janeyx99, Mar 13 '24 22:03

n_layer: 12
n_head: 12
kv_heads: 6 (GQA)
hidden_dim: 1536
n_tokens: 2048 (context length)
vocab_dim: 65536
activation: "swiglu"

No AMP/gradscaler. If a profile would help, I can produce one.
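As a sanity check, the config above is consistent with the stated ~512M parameters under a few assumptions the thread does not confirm: head_dim = hidden_dim / n_head, untied input embedding and LM head, no biases, norm weights ignored, and a SwiGLU intermediate size of 4096 (roughly 8/3 × hidden_dim). A sketch of the count:

```python
def transformer_params(n_layer=12, n_head=12, kv_heads=6, hidden=1536,
                       vocab=65536, ffn=4096, tied_embeddings=False):
    """Approximate parameter count; ignores norm weights and assumes no biases."""
    head_dim = hidden // n_head            # 1536 / 12 = 128
    kv_dim = kv_heads * head_dim           # 6 KV heads (GQA) -> 768
    # Attention: Q and O projections are hidden x hidden; K and V are hidden x kv_dim.
    attn = 2 * hidden * hidden + 2 * hidden * kv_dim
    # SwiGLU MLP: gate and up projections (hidden -> ffn) plus down projection (ffn -> hidden).
    mlp = 3 * hidden * ffn                 # ffn=4096 is an assumed value
    embed = vocab * hidden * (1 if tied_embeddings else 2)  # input embedding + LM head
    return n_layer * (attn + mlp) + embed


total = transformer_params()
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

Under these assumptions the embedding and LM head alone account for about 201M parameters. A matching total is only suggestive; it does not confirm the actual architecture.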

ad8e, Mar 13 '24 22:03

A profile would be helpful, yes!

janeyx99, Mar 14 '24 01:03

These trace.json files were gigantic (multi-GB), so here are smaller versions, without stack information and covering only two steps:

https://drive.google.com/file/d/1c2CQST_U_Qf6O1qgSXr0DKVorytoNgp5/view?usp=sharing

https://drive.google.com/file/d/1NCi6frLbXfVhL0pzdmwoRzmYLtFK-qeQ/view?usp=sharing

The torch.compile mode is the default (not "reduce-overhead"), and my TFLOPS are constant from the second step onward, so skipping just one warmup step seems OK.

The program is being launched by torchrun on a 1-GPU machine.
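For reference, a capture like the one described (first step skipped as warmup, no stack information) can be set up with torch.profiler's schedule API. A minimal sketch; `profile_train_steps` and the toy model are illustrative, not the reporter's code:

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def profile_train_steps(step_fn, out_dir, skip_steps=1, active_steps=3):
    """Profile `active_steps` calls of step_fn after skipping `skip_steps`
    (e.g. the torch.compile warmup step) and save one Chrome trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    saved = []

    def save_trace(prof):
        path = os.path.join(out_dir, f"trace_step{prof.step_num}.json")
        prof.export_chrome_trace(path)
        saved.append(path)

    with profile(
        activities=activities,
        # wait: steps ignored entirely; warmup: profiler runs but results are discarded
        schedule=schedule(wait=skip_steps, warmup=1, active=active_steps, repeat=1),
        on_trace_ready=save_trace,
        with_stack=False,  # the reporter dropped stack info to shrink multi-GB traces
    ) as prof:
        for _ in range(skip_steps + 1 + active_steps):
            step_fn()
            prof.step()  # advance the profiler schedule once per training step
    return saved


# Toy stand-in for the real training step.
model = torch.nn.Linear(256, 256)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)


def toy_step():
    loss = model(torch.randn(32, 256)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)


with tempfile.TemporaryDirectory() as out_dir:
    paths = profile_train_steps(toy_step, out_dir)
    sizes = [os.path.getsize(p) for p in paths]  # check before the directory is deleted
print(len(paths), sizes)
```

The resulting JSON files open in chrome://tracing or Perfetto.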

ad8e, Mar 14 '24 02:03

It turns out the step count isn't an issue once stack information is disabled. Here are 3 steps:

Unfused: https://drive.google.com/file/d/1OKuJVC3PPK5vn6-TWdi1vPD8bAoHP7h4/view?usp=sharing

Fused: https://drive.google.com/file/d/1CUeKkwiOMEHRJaNZlC65utq6WuS7Fxh4/view?usp=sharing

ad8e, Mar 14 '24 03:03

Oh! Wait a moment... are you comparing torch.compile(step) against the step of AdamW(fused=True)? torch.compile(step) does vertical fusion by default, so it being faster is reasonable (and the point, haha)!

If you're comparing the eager (non-compiled) step of AdamW(fused=True) against the default AdamW(), then the fused one should be faster.
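To check the eager-mode comparison in isolation, one can time `opt.step()` alone with fake gradients, outside the model. A minimal sketch, assuming hypothetical parameter shapes (swap in your model's) and wall-clock timing with `torch.cuda.synchronize()` so queued GPU kernels are counted. Without CUDA it times only the foreach path, since `fused=True` required CUDA tensors in PyTorch 2.2:

```python
import time

import torch


def time_optimizer_step(make_opt, shapes, device, iters=20, warmup=3):
    """Average milliseconds per opt.step() for fresh parameters of the given shapes."""
    params = [torch.randn(s, device=device, requires_grad=True) for s in shapes]
    for p in params:
        p.grad = torch.randn_like(p)  # fixed fake gradients; only the step is timed
    opt = make_opt(params)
    for _ in range(warmup):  # the first steps also allocate exp_avg / exp_avg_sq
        opt.step()
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        opt.step()
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) * 1e3 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical shapes: a few 1536x1536 matrices plus small vectors.
shapes = [(1536, 1536)] * 4 + [(1536,)] * 8
impls = {"foreach": lambda ps: torch.optim.AdamW(ps, lr=1e-3, foreach=True)}
if device == "cuda":  # fused=True needed CUDA tensors in PyTorch 2.2
    impls["fused"] = lambda ps: torch.optim.AdamW(ps, lr=1e-3, fused=True)
results = {name: time_optimizer_step(make, shapes, device) for name, make in impls.items()}
print({name: f"{ms:.3f} ms/step" for name, ms in results.items()})
```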

janeyx99, Mar 14 '24 16:03

No: torch.compile wraps the forward pass and loss, but backward and AdamW run outside the compiled region. So the comparison is:

torch.compile(forward+loss) + eager backward + unfused AdamW
vs.
torch.compile(forward+loss) + eager backward + fused AdamW

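import torch
import torch.nn.functional as F

# Only this part is compiled: forward pass + loss.
def fused_forward_and_loss(input_ids, labels):
    logits = model(input_ids)
    # Flatten (batch, seq, vocab) -> (batch * seq, vocab) and compute the loss in fp32.
    return F.cross_entropy(logits.view(-1, logits.shape[-1]).float(), labels.view(-1))

fused_forward_and_loss = torch.compile(fused_forward_and_loss)

...
loss = fused_forward_and_loss(input_ids, labels)
loss.backward()  # eager backward, outside torch.compile
opt.step()       # eager AdamW step (fused=True, or the default foreach path)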

ad8e, Mar 14 '24 16:03

Ah, thanks for that. I was looking for the optimizer ops in the torch.compile'd region and getting confused.

I looked at both traces, and you're right: the fused kernel is slower. In detail:

The default AdamW uses a series of foreach ops; the corresponding part of the trace: [trace screenshot]

The fused AdamW, by contrast, is a foreach_add (the step update) followed by one single fused kernel; the corresponding part of the trace: [trace screenshot]

cc @crcrpar: there are some concerning observations:

  1. While the fused implementation correctly launches fewer kernels (one fused_adamw instead of a bunch of foreach_* kernels), each kernel is much slower! Each fused_adamw kernel takes ~6.5 ms, whereas a single foreach_* kernel takes only 0.2 ms (more than 30 times faster). Why is the fused op so much slower?
  2. While each foreach op launches 137 kernels, the fused op launches only 75 (due to the multi_tensor_apply chunking logic). Given the same tensor inputs, I'm surprised the smaller foreach ops launch more kernels. Do you know what accounts for this?
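To reason about those launch counts, here is a simplified pure-Python model of multi_tensor_apply-style chunking, written from the general scheme rather than copied from PyTorch: each tensor is cut into fixed-size chunks (one CUDA block each), and a launch is flushed when a per-launch tensor table or block table fills up. As far as I know, PyTorch shrinks the tensor-table capacity as the number of tensor lists a kernel touches grows, so a fused Adam kernel (param, grad, exp_avg, exp_avg_sq) holds fewer tensors per launch than a one- or two-list foreach op. The capacities below (110 vs 36 tensors, 320 blocks, 65536-element chunks) and the parameter list are illustrative assumptions, and the model does not by itself explain the 137 vs 75 counts above:

```python
import math


def count_launches(numels, chunk_size, max_tensors, max_blocks):
    """Count launches for one multi_tensor_apply-style pass over non-empty tensors.

    Each tensor is split into ceil(numel / chunk_size) chunks, one block per chunk.
    A launch is flushed when the tensor table is full (after a tensor's last chunk),
    when the block table is full, or after the final chunk overall. A tensor cut off
    by a full block table stays loaded for the next launch.
    """
    launches = 0
    loaded_tensors = 0
    blocks = 0
    for t_idx, numel in enumerate(numels):
        loaded_tensors += 1
        n_chunks = math.ceil(numel / chunk_size)
        for c in range(n_chunks):
            blocks += 1
            last_chunk_of_tensor = c == n_chunks - 1
            tensors_full = loaded_tensors == max_tensors and last_chunk_of_tensor
            blocks_full = blocks == max_blocks
            final_chunk = t_idx == len(numels) - 1 and last_chunk_of_tensor
            if tensors_full or blocks_full or final_chunk:
                launches += 1
                blocks = 0
                loaded_tensors = 0 if last_chunk_of_tensor else 1
    return launches


# Hypothetical parameter list: 100 small vectors and 4 large matrices.
numels = [1536] * 100 + [1536 * 1536] * 4
for max_tensors in (110, 36):  # roomier vs. tighter tensor table (illustrative sizes)
    print(max_tensors, count_launches(numels, chunk_size=65536,
                                      max_tensors=max_tensors, max_blocks=320))
```

In this model, many small tensors make the tensor-table cap dominate, while a few huge tensors make the block cap dominate.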

janeyx99, Mar 14 '24 17:03

PyTorch v2.2.0 doesn't seem to include #117872, and neither does 2.2.1, judging by https://github.com/pytorch/pytorch/blob/v2.2.1/aten/src/ATen/native/cuda/fused_adam_utils.cuh. So I guess a nightly would be worth trying.

crcrpar, Mar 14 '24 18:03

Instructions for installing a nightly build can be found in the install selector at https://pytorch.org/

janeyx99, Mar 14 '24 18:03

My CUDA version is 12.2, and PyTorch only lists 12.1 in its nightlies. Is that still ok for me to install?

ad8e, Mar 14 '24 18:03

Yes, I think so

janeyx99, Mar 14 '24 20:03

It broke when I tried to run with PyTorch nightly: mismatched CUDA versions.

/home/kevin/.local/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/kevin/.local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
Traceback (most recent call last):
  File "/mnt/clusterstorage/workspace/kevin/basedformer/basedformer/models/neoxflash.py", line 22, in <module>
    from flash_attn import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: libc10_cuda.so: cannot open shared object file: No such file or directory
...
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: libc10_cuda.so: cannot open shared object file: No such file or directory
E0314 20:23:26.930000 139961313477760 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 5073) of binary: /usr/bin/python3

Debugging this isn't your issue though.

Updating from CoreWeave's stable images to their nightly images a few days ago (http://ghcr.io/coreweave/ml-containers/nightly-torch-extras:999655b-nccl-2024.03.11.05-cuda12.2.2-ubuntu22.04-nccl2.19.3-1-torch2.3.0a0-vision0.18.0a0-audio2.2.0a0-flash_attn2.4.2) produced NaNs in backward, but that also isn't your issue. (I also can't switch my container easily.)

It seems I don't have a good way to install PyTorch nightly on my current system.

ad8e, Mar 14 '24 20:03

Here are PyTorch nightly profiles (CUDA 12.2, 2.3.0a0+3eb322f). torch.compile broke, so I turned it off for forward+loss. Performance is much closer: unfused is only slightly faster than fused.

https://drive.google.com/file/d/1k8zoSGeK7Pr5jst_MU4u6HJS6lhikUcl/view?usp=sharing

https://drive.google.com/file/d/1weyCO-EAnnte0rJ45qFEvNR1UYpYmTE0/view?usp=sharing

ad8e, Mar 15 '24 18:03

Thank you for getting the traces! It is surprising that the fused step is still slower than the foreach one, but the trace shows only 67% occupancy per fused Adam kernel versus 100% occupancy for each foreach op kernel. @crcrpar, why could this be the case?

Unfused, 100% occupancy: [trace screenshot]

Fused, 67% occupancy: [trace screenshot]

The fused kernel does use almost twice as many registers. That might be because it does more work per kernel, but it shouldn't be doing more work than the foreach ops combined.

BTW, @yifuwang's occupancy changes (with dynamic chunking) will help make both implementations faster.
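As a rough check on the 100% vs 67% figures, here is a theoretical-occupancy calculator limited by registers and threads only. Assumptions, none confirmed in the thread: A40 (compute capability 8.6) limits of 65,536 registers, 1,536 resident threads and 16 resident blocks per SM; registers allocated per warp in units of 256; and a 512-thread block for these kernels. It ignores shared memory:

```python
import math


def occupancy(regs_per_thread, block_size=512, regs_per_sm=65536,
              max_threads_per_sm=1536, max_blocks_per_sm=16, reg_granularity=256):
    """Theoretical occupancy (active warps / max warps) limited by registers and threads."""
    warps_per_block = math.ceil(block_size / 32)
    # Registers are allocated per warp, rounded up to the allocation granularity.
    regs_per_warp = math.ceil(regs_per_thread * 32 / reg_granularity) * reg_granularity
    blocks_by_regs = regs_per_sm // (regs_per_warp * warps_per_block)
    blocks_by_threads = max_threads_per_sm // block_size
    blocks = min(blocks_by_regs, blocks_by_threads, max_blocks_per_sm)
    max_warps = max_threads_per_sm // 32
    return blocks * warps_per_block / max_warps


for regs in (32, 40, 41, 64):
    print(regs, f"{occupancy(regs):.0%}")
```

Under these assumptions, 100% means three 512-thread blocks per SM and 67% means two, which happens once a kernel needs more than 40 registers per thread. That is consistent with the fused kernel using roughly twice the registers of the foreach kernels.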

janeyx99, Mar 15 '24 20:03

As a side note, after you mentioned it, I tried @torch.compile on PyTorch's unfused AdamW. TFLOPS went from 85 to 4, haha. It's because torch.compile keeps recompiling whenever the LR changes (which is inevitable with a learning rate scheduler).

ad8e, Mar 22 '24 06:03

Haha, thank you for gathering that number (though it's abysmal :( ). Is this true even when you wrap lr in a tensor, like torch.tensor(lr)?

There is an in-progress tracking issue for lr: https://github.com/pytorch/pytorch/issues/120934 cc @mlazos

janeyx99, Mar 23 '24 06:03

Using AdamW(lr=torch.tensor(...)) with the scheduler active, TFLOPS went from 85 to 88. Zoom zoom!

The LR scheduler is LambdaLR with regular (non-tensor) floats.
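A sketch of the tensor-lr idea: keep the learning rate in a 0-dim tensor and update it in place, so a compiled optimizer step sees the same tensor object with a new value instead of a new Python float constant. Assumptions: `lr_lambda` is a hypothetical warmup schedule standing in for the reporter's LambdaLR lambda, `set_lr_in_place` is a hypothetical helper, and tensor-lr support in AdamW has varied across PyTorch versions, so check your version's docs before relying on it:

```python
import torch


def lr_lambda(step: int, base_lr: float = 1e-3, warmup: int = 100) -> float:
    """Hypothetical linear-warmup schedule standing in for a LambdaLR lambda."""
    return base_lr * min(1.0, (step + 1) / warmup)


def set_lr_in_place(opt: torch.optim.Optimizer, value: float) -> None:
    """Write a new LR into every param group without replacing a tensor lr object."""
    for group in opt.param_groups:
        if isinstance(group["lr"], torch.Tensor):
            group["lr"].fill_(value)  # same tensor object, new value
        else:
            group["lr"] = value  # plain float: a compiled step may recompile on this


model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=torch.tensor(lr_lambda(0)))
lr_tensor = opt.param_groups[0]["lr"]

for step in range(3):
    set_lr_in_place(opt, lr_lambda(step))
    # ... forward, loss.backward(), then the (optionally compiled) opt.step()

print(lr_tensor.item())
```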

ad8e, Mar 23 '24 06:03

icy χατγιρλ: @ad8e I can reproduce the performance dropoff under torch 2.2
icy χατγιρλ: compiling only forward + loss, enabling the fused optimizer completely destroys the performance
icy χατγιρλ: but on 2.4 nightly, enabling fused helps a tiny bit

ad8e, Apr 07 '24 17:04

@ad8e, just a heads-up: there was a regression in fused optimizers in the past few days. https://github.com/pytorch/pytorch/pull/123566 should fix it.

we only get 67% occupancy per fused adam kernel and 100% occupancy for each foreach op kernel

@janeyx99, the foreach version writes intermediate results to global memory after every op. The fused version only writes the final results and keeps intermediates in registers, which requires more registers per thread and can lead to lower occupancy. I think it's normal for the fused version to be faster despite lower occupancy, because it avoids a lot of expensive memory I/O.
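That trade-off can be made concrete with a pure-Python model (not PyTorch's kernels): lists stand in for global memory, and each function counts element reads and writes. The multi-pass version loosely mirrors a chain of foreach ops, simplified to five stages (the real foreach AdamW issues more ops); the single-pass version mirrors a fused kernel. The update follows the standard AdamW form with decoupled weight decay and bias correction:

```python
import math


def adamw_multi_pass(p, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=1e-2):
    """foreach-style: each stage is its own pass that reads inputs from memory and
    writes its result back (including a temporary `denom` buffer).
    Returns element reads + writes."""
    n = len(p)
    bc1, bc2 = 1 - b1**step, 1 - b2**step
    traffic = 0
    for i in range(n):                           # p *= 1 - lr*wd
        p[i] *= 1 - lr * wd
    traffic += 2 * n                             # read p, write p
    for i in range(n):                           # m = b1*m + (1-b1)*g
        m[i] = b1 * m[i] + (1 - b1) * g[i]
    traffic += 3 * n                             # read m, g; write m
    for i in range(n):                           # v = b2*v + (1-b2)*g*g
        v[i] = b2 * v[i] + (1 - b2) * g[i] * g[i]
    traffic += 3 * n                             # read v, g; write v
    denom = [math.sqrt(x) / math.sqrt(bc2) + eps for x in v]
    traffic += 2 * n                             # read v, write temporary denom
    for i in range(n):                           # p -= lr/bc1 * m / denom
        p[i] -= lr / bc1 * m[i] / denom[i]
    traffic += 4 * n                             # read p, m, denom; write p
    return traffic


def adamw_single_pass(p, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=1e-2):
    """fused-style: one pass; intermediates live in locals (registers on a GPU)."""
    bc1, bc2 = 1 - b1**step, 1 - b2**step
    for i in range(len(p)):
        pi, gi = p[i], g[i]                      # read p, g
        mi = b1 * m[i] + (1 - b1) * gi           # read m
        vi = b2 * v[i] + (1 - b2) * gi * gi      # read v
        pi *= 1 - lr * wd
        pi -= lr / bc1 * mi / (math.sqrt(vi) / math.sqrt(bc2) + eps)
        p[i], m[i], v[i] = pi, mi, vi            # write p, m, v
    return 7 * len(p)


p0, g = [0.5, -1.0, 2.0, 0.1], [0.01, -0.02, 0.03, 0.0]
p_a, m_a, v_a = list(p0), [0.0] * 4, [0.0] * 4
p_b, m_b, v_b = list(p0), [0.0] * 4, [0.0] * 4
t_multi = adamw_multi_pass(p_a, g, m_a, v_a, step=1)
t_single = adamw_single_pass(p_b, g, m_b, v_b, step=1)
print(p_a == p_b, t_multi, t_single)  # True 56 28
```

Both produce identical parameters, but in this model the multi-pass version moves twice as many elements through memory (14 vs 7 accesses per element). The single pass keeps several values live at once per element, which on a GPU means more registers per thread, hence the lower occupancy.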

yifuwang, Apr 09 '24 00:04

PyTorch nightly:
fused AdamW: 69.5 TFLOPS
unfused AdamW: 70.5 TFLOPS

One of the two fused AdamW runs has a wobbly TFLOPS line, going up and down around 69.5: [Screenshot 2024-04-30 at 11 45 12 PM]

PyTorch version: 2.4.0a0+1bcbc91
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.19.17-coreweave-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 535.161.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          96
On-line CPU(s) list:             0-95
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7413 24-Core Processor
CPU family:                      25
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
Stepping:                        1
Frequency boost:                 enabled
CPU max MHz:                     3630.8101
CPU min MHz:                     1500.0000
BogoMIPS:                        5299.59
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                  AMD-V
L1d cache:                       1.5 MiB (48 instances)
L1i cache:                       1.5 MiB (48 instances)
L2 cache:                        24 MiB (48 instances)
L3 cache:                        256 MiB (8 instances)
NUMA node(s):                    8
NUMA node0 CPU(s):               0-5,48-53
NUMA node1 CPU(s):               6-11,54-59
NUMA node2 CPU(s):               12-17,60-65
NUMA node3 CPU(s):               18-23,66-71
NUMA node4 CPU(s):               24-29,72-77
NUMA node5 CPU(s):               30-35,78-83
NUMA node6 CPU(s):               36-41,84-89
NUMA node7 CPU(s):               42-47,90-95
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0a0+1bcbc91
[pip3] torchaudio==2.2.0a0+ea437b3
[pip3] torchvision==0.19.0a0+06ad737
[pip3] triton==3.0.0
[conda] Could not collect

ad8e, May 01 '24 06:05