lightning-thunder
Different behavior from torch.full when checking whether `fill_value` can be cast to `dtype`
🐛 Bug
The cast check in https://github.com/Lightning-AI/lightning-thunder/blob/019557e1cbd2944a4d8d719954ee8a7c295c539b/thunder/core/prims.py#L2650-L2655 is not the same as the overflow check that torch.full uses (https://github.com/pytorch/pytorch/blob/f5e704a6f25939478f770f8980c344ab461f0113/c10/util/Half.h#L462).
To Reproduce
An example of the inconsistency:
import torch
import thunder

def func():
    return torch.full((1,), 0.5, device="cuda", dtype=torch.bool)

o = func()
print(o)

jfunc = thunder.jit(func)
o1 = jfunc()
print(o1)
PyTorch outputs `tensor([True], device='cuda:0')`.
Thunder raises `RuntimeError: Can't safely cast fill_value of numbertype <class 'float'> to dtype thunder.dtypes.bool8`.
Expected behavior
Thunder should accept and reject the same fill values as PyTorch. For reference, torch.full behaves as follows:
>>> import sys
>>> import torch
>>> torch.full((2,2),0.5,dtype=torch.bool)
tensor([[True, True],
        [True, True]])
>>> torch.full((2,2),0.5,dtype=torch.int)
tensor([[0, 0],
        [0, 0]], dtype=torch.int32)
>>> torch.full((2,2),-1,dtype=torch.uint32)
tensor([[4294967295, 4294967295],
        [4294967295, 4294967295]], dtype=torch.uint32)
>>> torch.full((2,2),sys.maxsize,dtype=torch.uint32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: value cannot be converted to type uint32_t without overflow
>>> x=torch.full((1,2), 1j, dtype=torch.float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: value cannot be converted to type float without overflow
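The outcomes in the transcript above are consistent with an overflow-style check rather than a "safe cast" check. Below is a minimal pure-Python sketch of such a check, modeled on our reading of c10's `overflows`/`checked_convert` rules: conversion to bool is never checked, floats are range-checked (so 0.5 fits int32), negative ints may wrap into unsigned types when their magnitude fits, and complex-to-real counts as overflow when the imaginary part is non-zero. `DTYPE_LIMITS`, `overflows`, and the string dtype names are hypothetical illustrations, not Thunder or PyTorch APIs, and only a handful of dtypes are covered.

```python
import math

# Hypothetical dtype table (not a Thunder or PyTorch API): name -> (kind, lowest, max).
# Limits mirror std::numeric_limits of the underlying C++ types; complex64 is
# checked component-wise against float32 limits.
_FLOAT32_MAX = 3.4028234663852886e38
_FLOAT64_MAX = 1.7976931348623157e308
DTYPE_LIMITS = {
    "bool": ("bool", 0, 1),
    "uint8": ("int", 0, 2**8 - 1),
    "int32": ("int", -(2**31), 2**31 - 1),
    "uint32": ("int", 0, 2**32 - 1),
    "int64": ("int", -(2**63), 2**63 - 1),
    "float32": ("float", -_FLOAT32_MAX, _FLOAT32_MAX),
    "float64": ("float", -_FLOAT64_MAX, _FLOAT64_MAX),
    "complex64": ("complex", -_FLOAT32_MAX, _FLOAT32_MAX),
}


def _real_overflows(x, kind, lo, hi):
    """Return True if the real number x does not fit a target of the given kind."""
    if isinstance(x, float):
        if math.isinf(x) or math.isnan(x):
            # Only floating-point (and complex) targets can represent inf and NaN.
            return kind == "int"
        return x < lo or x > hi
    # x is a Python int here.
    if kind == "int" and lo == 0 and x < 0:
        # Unsigned targets: a negative int may wrap (two's complement) as long
        # as its magnitude does not exceed the type's max.
        return -x > hi
    return x < lo or x > hi


def overflows(value, dtype):
    """Hypothetical sketch: True if value should be rejected as a fill value for dtype."""
    kind, lo, hi = DTYPE_LIMITS[dtype]
    if kind == "bool" or isinstance(value, bool):
        # Conversion to bool is never checked, and bool values fit every dtype.
        return False
    if isinstance(value, complex):
        if kind != "complex" and value.imag != 0:
            # complex -> real counts as overflow when the imaginary part is non-zero.
            return True
        return (_real_overflows(value.real, kind, lo, hi)
                or _real_overflows(value.imag, kind, lo, hi))
    return _real_overflows(value, kind, lo, hi)
```

For example, `overflows(0.5, "bool")`, `overflows(0.5, "int32")` and `overflows(-1, "uint32")` return `False`, while `overflows(2**63 - 1, "uint32")` and `overflows(1j, "float32")` return `True`, matching the five torch.full calls above.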
Additional context
This issue is a follow-up to https://github.com/Lightning-AI/lightning-thunder/pull/949
cc @apaz-cli