
Support of bfloat16 data type

Open ufimtsev opened this issue 3 years ago • 5 comments

🚀 Feature

Please add support for the bfloat16 data type. bfloat16 operations are fast when natively supported in hardware (starting with the Ampere architecture), they require less memory than 32-bit floats, and, most importantly, training with bfloat16 is considerably more stable than training with fp16.
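A minimal PyTorch-only sketch of the kind of workflow this would enable (the model and tensors here are illustrative, not DGL API; the request is for DGL's kernels to accept torch.bfloat16 inputs the same way):

import torch

# Illustrative model and data; any nn.Module and feature tensors work the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
).to(device='cuda', dtype=torch.bfloat16)

x = torch.randn(8, 16, device='cuda', dtype=torch.bfloat16)  # e.g. node features
y = torch.randn(8, 1, device='cuda', dtype=torch.bfloat16)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss = torch.nn.functional.mse_loss(model(x), y)  # forward/backward run entirely in bf16
loss.backward()
opt.step()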

ufimtsev avatar Aug 04 '22 23:08 ufimtsev

Thanks for the suggestion. Just curious: does PyTorch support bfloat16 natively?

jermainewang avatar Aug 08 '22 04:08 jermainewang

> Thanks for the suggestion. Just curious: does PyTorch support bfloat16 natively?

Yes. PyTorch supports bfloat16 for both CPU and GPU.
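For reference, a quick sanity check on both devices (a sketch; the dtypes are what matter here):

import torch

# CPU: bfloat16 tensors and basic arithmetic are available.
a = torch.randn(4, 4).to(torch.bfloat16)
print((a @ a).dtype)      # torch.bfloat16

# GPU: the same tensor can be moved to CUDA and used there.
if torch.cuda.is_available():
    b = a.cuda()
    print((b @ b).dtype)  # torch.bfloat16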

yaox12 avatar Aug 08 '22 05:08 yaox12

One needs a GPU with hardware bfloat16 support. In my case, torch.cuda.is_bf16_supported() returns True on an RTX 3090 and False on a Titan V. Tested with PyTorch 1.12.0.

ufimtsev avatar Aug 08 '22 21:08 ufimtsev

bfloat16 requires compute capability >= 8.0 and CUDA >= 11.
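Both requirements can be checked from Python (a small sketch, assuming a CUDA build of PyTorch):

import torch

major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 6) on RTX 3090, (7, 0) on Titan V
print(torch.version.cuda)              # CUDA version the PyTorch build targets, e.g. '11.6'
print(major >= 8)                      # hardware bf16 support
print(torch.cuda.is_bf16_supported())  # PyTorch's combined check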

yaox12 avatar Aug 09 '22 01:08 yaox12

Do we need fallbacks for __CUDA_ARCH__ < 800? cc @nv-dlasalle

For PyTorch,

  1. bf16 arithmetic functions are supported on all CUDA architectures. For example, the following code runs even though torch.cuda.is_bf16_supported() returns False.
In [1]: import torch

In [2]: torch.cuda.is_bf16_supported()
Out[2]: False

In [3]: c = torch.ones(10, dtype=torch.bfloat16, device='cuda')

In [4]: l = torch.nn.Linear(10, 10).to(torch.bfloat16).to('cuda')

In [5]: l(c)
Out[5]:
tensor([ 0.1250, -0.2461, -0.2559,  0.3047, -0.2871, -0.0059, -0.7812, -1.1641,
         0.7148, -0.0771], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<AddBackward0>)
  2. AMP doesn't support bf16 for compute capability < 8.0 (__CUDA_ARCH__ < 800); a dtype-selection fallback is sketched after the log below.
In [1]: import torch

In [2]: torch.cuda.is_bf16_supported()
Out[2]: False

In [3]: torch.cuda.amp.autocast(dtype=torch.bfloat16)
---------------------------------------------------------------------------
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
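So on the Python side, a minimal fallback sketch is simply to pick the autocast dtype at runtime (not a DGL change, only an illustration):

import torch

# Prefer bf16 where the device supports it; otherwise fall back to fp16.
amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

model = torch.nn.Linear(10, 10).cuda()
x = torch.randn(2, 10, device='cuda')

with torch.cuda.amp.autocast(dtype=amp_dtype):
    out = model(x)
print(out.dtype)  # torch.bfloat16 or torch.float16, depending on the device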

yaox12 avatar Sep 21 '22 06:09 yaox12