Support for the bfloat16 data type
🚀 Feature
Please add support for the bfloat16 data type. bfloat16 operations are fast when natively supported (starting with the Ampere architecture), require less memory than 32-bit floats, and, most importantly, training in bfloat16 is considerably more stable than training in fp16.
Thanks for the suggestion. Just curious: does PyTorch support bfloat16 natively?
Yes. PyTorch supports bfloat16 for both CPU and GPU.
You need a GPU with hardware bfloat16 support. In my case, torch.cuda.is_bf16_supported() returns True on an RTX 3090 and False on a Titan V. Tested with PyTorch 1.12.0.
bfloat16 requires compute capability >= 8.0 and CUDA >= 11.
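For reference, a minimal sketch of how such a check could look on the user side, combining the compute-capability requirement with PyTorch's built-in query (the helper name has_native_bf16 is made up for illustration, not an existing API):

```python
import torch

def has_native_bf16() -> bool:
    """Hypothetical helper: does the current CUDA device support bf16 natively?"""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()  # e.g. (8, 6) on an RTX 3090
    # Assumes the PyTorch build itself was compiled against CUDA >= 11.
    return major >= 8 and torch.cuda.is_bf16_supported()

print(has_native_bf16())
```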
Do we need fallbacks for __CUDA_ARCH__ < 800? cc @nv-dlasalle
For PyTorch:
- bf16 arithmetic is supported on all CUDA architectures. For example, the following code runs even on a device where torch.cuda.is_bf16_supported() returns False.
In [1]: import torch
In [2]: torch.cuda.is_bf16_supported()
Out[2]: False
In [3]: c = torch.ones(10, dtype=torch.bfloat16, device='cuda')
In [4]: l = torch.nn.Linear(10, 10).to(torch.bfloat16).to('cuda')
In [5]: l(c)
Out[5]:
tensor([ 0.1250, -0.2461, -0.2559, 0.3047, -0.2871, -0.0059, -0.7812, -1.1641,
0.7148, -0.0771], device='cuda:0', dtype=torch.bfloat16,
grad_fn=<AddBackward0>)
- AMP doesn't support bf16 for __CUDA_ARCH__ < 800:
In [1]: import torch
In [2]: torch.cuda.is_bf16_supported()
Out[2]: False
In [3]: torch.cuda.amp.autocast(dtype=torch.bfloat16)
---------------------------------------------------------------------------
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
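So for the autocast path, a dtype-selection fallback is one option. Below is a minimal, hedged sketch of such a fallback in user code, assuming one simply prefers bf16 autocast where the device supports it and fp16 otherwise, as the error message suggests (the policy is illustrative, not an existing library API):

```python
import torch
import torch.nn as nn

# Assumed fallback policy: bf16 autocast on devices that support it,
# fp16 autocast otherwise; CPU autocast uses bf16.
if torch.cuda.is_available():
    device = "cuda"
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device = "cpu"
    amp_dtype = torch.bfloat16

model = nn.Linear(10, 10).to(device)
x = torch.randn(4, 10, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    y = model(x)

print(y.dtype)  # torch.bfloat16 where supported, torch.float16 on older GPUs
```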