
Large Performance Regression with FusedAdam

Open · Kaixhin opened this issue 1 year ago · 3 comments

I'm running some reinforcement learning experiments and noticed a large performance regression when swapping out PyTorch modules for APEX ones. I've narrowed it down to FusedAdam being used to replace AdamW, as can be seen in the plots below.

Optimiser | Train | Test
optim.AdamW | (pre_apex_train_returns plot) | (pre_apex_test_returns plot)
optimizers.FusedAdam | (train_returns plot) | (test_returns plot)

The relevant code snippets below show the switch at the beginning of my code, the optimiser setup for my component models, and the training loop:

import importlib.util

import torch
from torch import optim

if importlib.util.find_spec('apex') and torch.cuda.is_available():  # Use FusedAdam if NVIDIA Apex is available
  from apex import optimizers
  Adam = optimizers.FusedAdam  # Implements AdamW weight decay by default
  Adam.zero_grad = lambda *args, **kwargs: None  # Patch the APEX optimiser to match the standard PyTorch API (setting grads to None is handled by default in APEX)
else:
  Adam = optim.AdamW

# Optimiser setup for the component models
encoder_optimiser = Adam(encoder.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.weight_decay)
critic_optimiser = Adam(critic.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.weight_decay)

# Training loop (per update)
encoder_optimiser.zero_grad(set_to_none=True)
critic_optimiser.zero_grad(set_to_none=True)
value_loss.backward()
encoder_optimiser.step()
critic_optimiser.step()

I would expect FusedAdam to act as a faster drop-in replacement for AdamW, especially with lr and weight_decay set explicitly. For reference, the experiments above use a learning rate of 0.0003 and a weight decay of 0.
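To make that expectation concrete, here is a minimal single-step parity check one could run (a sketch only; it assumes Apex is installed with its CUDA extensions, and the toy model and hyperparameters are placeholders, not my actual setup):

import copy
import torch
from torch import nn, optim
from apex import optimizers

torch.manual_seed(0)
ref = nn.Linear(8, 4).cuda()  # toy model updated with torch.optim.AdamW
fused = copy.deepcopy(ref)    # identical copy updated with APEX FusedAdam

opt_ref = optim.AdamW(ref.parameters(), lr=3e-4, weight_decay=0)
opt_fused = optimizers.FusedAdam(fused.parameters(), lr=3e-4, weight_decay=0)  # adam_w_mode=True by default

x = torch.randn(16, 8, device='cuda')
for model, opt in ((ref, opt_ref), (fused, opt_fused)):
  model(x).pow(2).mean().backward()  # same loss on the same input
  opt.step()

for p_ref, p_fused in zip(ref.parameters(), fused.parameters()):
  print(torch.allclose(p_ref, p_fused, atol=1e-6))  # True if the two really act as drop-in replacements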


Environment information below (APEX was built on commit 1d77111):

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:56:21)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-47-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] botorch==0.6.4
[pip3] gpytorch==1.8.0
[pip3] numpy==1.21.5
[pip3] torch==1.12.1
[pip3] torchaudio==0.12.1
[pip3] torchvision==0.13.1
[conda] blas                      1.0                         mkl  
[conda] botorch                   0.6.4                         0    pytorch
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] gpytorch                  1.8.0              pyhd8ed1ab_0    conda-forge
[conda] libblas                   3.9.0            12_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            12_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            12_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            12_linux64_mkl    conda-forge
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.5           py39h6c91a56_3  
[conda] numpy-base                1.21.5           py39ha15fc14_3  
[conda] pytorch                   1.12.1          py3.9_cuda11.3_cudnn8.3.2_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.12.1               py39_cu113    pytorch
[conda] torchvision               0.13.1               py39_cu113    pytorch

Kaixhin · Oct 05 '22

@crcrpar just recently upstreamed FusedAdam in https://github.com/pytorch/pytorch/pull/85739 and was seeing a speedup (we had to resubmit it a few times, so you might need to check the previous PRs for more information). Could you check the native implementation and compare it to your runs?
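For reference, switching to the native fused implementation could look roughly like this (a sketch; it assumes a build where torch.optim.Adam accepts fused=True, e.g. a recent nightly, and the toy model is just a placeholder):

import torch
from torch import optim

model = torch.nn.Linear(8, 4).cuda()  # fused=True requires CUDA parameters
optimiser = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0, fused=True)  # native fused implementation

loss = model(torch.randn(16, 8, device='cuda')).pow(2).mean()
optimiser.zero_grad(set_to_none=True)
loss.backward()
optimiser.step()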

ptrblck · Oct 05 '22

Seems to be working fine so far - I installed the nightly build of PyTorch to test (I had to download it manually from conda because of a network error). Was this a valid way to test? For reference, here are the changed parts of my environment:

PyTorch version: 1.13.0.dev20221004
CUDA used to build PyTorch: 11.7

CUDA runtime version: 11.7.99

Versions of relevant libraries:
[pip3] torch==1.13.0.dev20221004
[pip3] torchaudio==0.13.0.dev20221004
[pip3] torchvision==0.15.0.dev20221004
[conda] pytorch                   1.13.0.dev20221004 py3.9_cuda11.7_cudnn8.5.0_0    <unknown>
[conda] pytorch-cuda              11.7                 h67b0de4_0    pytorch-nightly
[conda] torchaudio                0.13.0.dev20221004      py39_cu117    pytorch-nightly
[conda] torchvision               0.15.0.dev20221004      py39_cu117    pytorch-nightly

Note that I switched back to PyTorch LayerNorm for this experiment rather than rebuilding APEX, but in my prior experiments there was no noticeable difference between the two.

Kaixhin · Oct 05 '22

Results are in - it looks (largely) fine. There is a possible regression in performance, but with RL it's not one I'd consider significant, and I would have to do several runs to check; each run takes me almost one day.

Edit: So something seems to be up with FusedAdam for me - there's basically no learning. I've rerun this and updated the table in the first comment. What would you suggest next?

PyTorch Version | Train | Test
1.12.1 (py3.9_cuda11.3_cudnn8.3.2_0) | (pre_apex_train_returns plot) | (pre_apex_test_returns plot)
1.13.0.dev20221004 (py3.9_cuda11.7_cudnn8.5.0_0) | (train_returns plot) | (test_returns plot)

Kaixhin · Oct 06 '22

I've cleaned up my Python installation as much as possible, removed and reinstalled conda, and have tried again, this time with the PyTorch 1.13 stable release and Apex built at commit 6a40a0a. Results are still the same (I cut the experiment short after a few hours because the difference is clear):

Optimiser | Train | Test
optim.AdamW | (train_returns plot) | (test_returns plot)
optimizers.FusedAdam | (train_returns plot) | (test_returns plot)

New environment information:

PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-50-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 520.61.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] botorch==0.7.2
[pip3] gpytorch==1.9.0
[pip3] numpy==1.23.4
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0
[pip3] torchvision==0.14.0
[conda] blas                      2.116                       mkl    conda-forge
[conda] blas-devel                3.9.0            16_linux64_mkl    conda-forge
[conda] botorch                   0.7.2              pyhd8ed1ab_0    conda-forge
[conda] gpytorch                  1.9.0              pyhd8ed1ab_0    conda-forge
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.1.0           h84fe81f_915    conda-forge
[conda] mkl-devel                 2022.1.0           ha770c72_916    conda-forge
[conda] mkl-include               2022.1.0           h84fe81f_915    conda-forge
[conda] numpy                     1.23.4          py310h53a5b5f_1    conda-forge
[conda] pytorch                   1.13.0          py3.10_cuda11.7_cudnn8.5.0_0    pytorch
[conda] pytorch-cuda              11.7                 h67b0de4_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.13.0              py310_cu117    pytorch
[conda] torchvision               0.14.0              py310_cu117    pytorch

Kaixhin · Nov 01 '22

Just tested PyTorch 1.13 optim.Adam(..., fused=True); it is slower than fused=False (tested with a large transformer training run). The difference is about 5%.
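For context, a minimal way one might time just the optimiser step in isolation (illustrative only; the layer size, learning rate and step count below are placeholders, not the transformer benchmark mentioned above):

import torch
from torch import nn, optim

model = nn.Linear(4096, 4096).cuda()
model(torch.randn(64, 4096, device='cuda')).pow(2).mean().backward()  # populate gradients once

for fused in (False, True):
  opt = optim.Adam(model.parameters(), lr=3e-4, fused=fused)
  opt.step()  # warm-up; also initialises the optimiser state
  start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
  start.record()
  for _ in range(100):
    opt.step()
  end.record()
  torch.cuda.synchronize()
  print(f'fused={fused}: {start.elapsed_time(end) / 100:.3f} ms per step')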

vince62s · Nov 01 '22

@vince62s this issue is about model performance (reward, since this is an RL setup, but analogous to classification accuracy), not timing.

Kaixhin · Nov 02 '22

Based on https://github.com/pytorch/pytorch/issues/88258, I did new runs to see if the Adam(fused=True) implementation in PyTorch 1.13 causes the same issue, or if it is specific to APEX's FusedAdam. Learning happens fine with the native fused implementation, so the problem does seem to be something specific to APEX.

PyTorch Adam Implementation | Train | Test
Default | (train_returns plot) | (test_returns plot)
Fused | (train_returns plot) | (test_returns plot)

Kaixhin · Nov 09 '22