
Support for bfloat16

Open philippwitte opened this issue 1 year ago • 8 comments

Description

Are there plans to support the bfloat16 data type in the near future? This data type is becoming increasingly popular in LLM training, but it currently appears to be unsupported: for example, calling y = cp.asarray(x), where x is a torch tensor of type torch.bfloat16, raises "TypeError: Got unsupported ScalarType BFloat16". Are there any recommended workarounds in the meantime?
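
A minimal reproduction (a hypothetical sketch, assuming PyTorch and CuPy are installed and a CUDA device is available):

import torch
import cupy as cp

x = torch.arange(4, dtype=torch.bfloat16, device="cuda")
y = cp.asarray(x)  # TypeError: Got unsupported ScalarType BFloat16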

Additional Information

No response

philippwitte avatar Apr 21 '23 02:04 philippwitte

Curious where/how you would use bf16 if CuPy were to support it? Any pointers or references? Thanks! 🙂

leofang avatar Apr 21 '23 02:04 leofang

It would be good if NumPy data type extensions à la https://github.com/jax-ml/ml_dtypes/tree/main were supported, which include bfloat16, fp8, etc.
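
For reference, a short sketch of what ml_dtypes provides on the NumPy side (assuming the ml_dtypes package is installed; CuPy currently has no equivalent):

import numpy as np
import ml_dtypes

# ml_dtypes registers bfloat16 (and fp8 variants) as regular NumPy dtypes
x = np.array([1.0, 2.5, -3.0], dtype=ml_dtypes.bfloat16)
print(x.dtype)               # bfloat16
print(x.astype(np.float32))  # [ 1.   2.5 -3. ]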

jglaser avatar Sep 12 '23 23:09 jglaser

Seconding this! bfloat16 and fp8 support are important for my use case. I'd love to see these.

guberti avatar Sep 18 '23 00:09 guberti

Any progress on this? We really need it for LLM training and inference.

wuxibin89 avatar Sep 25 '23 06:09 wuxibin89

bfloat16 support is sorely missed in CuPy. Would really appreciate it being added! We are currently forced to work around it like this (thankfully we have torch.Tensor.view):

import cupy
import torch

x = torch.arange(10, dtype=torch.bfloat16, device="cuda")
print(x)
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], device='cuda:0',
#        dtype=torch.bfloat16)

# view as uint8
y = x.view(dtype=torch.uint8)

# wrap the raw CUDA pointer in a CuPy array (y must stay alive while arr is in use)
array_size_in_bytes = y.nelement() * y.element_size()
mem = cupy.cuda.UnownedMemory(y.data_ptr(), array_size_in_bytes, owner=None)
memptr = cupy.cuda.MemoryPointer(mem, offset=0)
arr = cupy.ndarray(y.size(), dtype=cupy.uint8, memptr=memptr)
out = torch.as_tensor(arr, device=x.device, dtype=torch.uint8)
print(out)
# tensor([  0,   0, 128,  63,   0,  64,  64,  64, 128,  64, 160,  64, 192,  64,
#         224,  64,   0,  65,  16,  65], device='cuda:0', dtype=torch.uint8)

# view as bfloat16 again
out = out.view(x.dtype)
print(out)
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], device='cuda:0',
#        dtype=torch.bfloat16)
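
For convenience, the same round trip can be wrapped in a small helper. This is only a sketch of the workaround above (the function name is made up, not a CuPy or PyTorch API); it assumes a contiguous bfloat16 tensor on a CUDA device and that the tensor outlives the CuPy view:

def bf16_tensor_to_cupy_u8(t: torch.Tensor) -> cupy.ndarray:
    """View a contiguous bfloat16 CUDA tensor as a CuPy uint8 array (zero copy)."""
    b = t.contiguous().view(torch.uint8)
    nbytes = b.nelement() * b.element_size()
    # passing the tensor as owner keeps a reference alongside the raw pointer
    mem = cupy.cuda.UnownedMemory(b.data_ptr(), nbytes, owner=b)
    memptr = cupy.cuda.MemoryPointer(mem, offset=0)
    return cupy.ndarray(tuple(b.size()), dtype=cupy.uint8, memptr=memptr)

# usage: operate on the raw bytes in CuPy, then view back as bfloat16 in torch
arr = bf16_tensor_to_cupy_u8(x)
roundtrip = torch.as_tensor(arr, device=x.device).view(torch.bfloat16)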

borisfom avatar Mar 28 '24 07:03 borisfom

I see (in https://github.com/cupy/cupy/issues/8269) that the bfloat16 feature is planned for the v14 release. @asi1024, is there a WIP branch that others can play with or contribute to if needed?

yuanlin2004 avatar Apr 29 '24 20:04 yuanlin2004

Would also love this!

dakofler avatar Aug 06 '24 18:08 dakofler