
Different behavior from torch.full when checking whether `fill_value` can be cast to `dtype`

Open. kiya00 opened this issue 1 year ago.

🐛 Bug

The check that Thunder uses to decide whether `fill_value` can be cast to `dtype` (https://github.com/Lightning-AI/lightning-thunder/blob/019557e1cbd2944a4d8d719954ee8a7c295c539b/thunder/core/prims.py#L2650-L2655) is not the same as the check PyTorch uses in torch.full (https://github.com/pytorch/pytorch/blob/f5e704a6f25939478f770f8980c344ab461f0113/c10/util/Half.h#L462), so some calls that succeed in PyTorch fail in Thunder.

To Reproduce

An example of the inconsistency:

```python
import torch
import thunder

def func():
    return torch.full((1,), 0.5, device="cuda", dtype=torch.bool)

# Eager PyTorch
o = func()
print(o)

# Same function compiled with Thunder
jfunc = thunder.jit(func)
o1 = jfunc()
print(o1)
```

PyTorch outputs `tensor([True], device='cuda:0')`, while Thunder raises `RuntimeError: Can't safely cast fill_value of numbertype <class 'float'> to dtype thunder.dtypes.bool8`.
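One way to see the difference: Thunder's error message refers to the *numbertype* of `fill_value`, which suggests (an assumption, not verified against `prims.py`) that its check compares the Python type of the value with the dtype's category and never looks at the value itself. The sketch below models such a type-based check; `can_safe_cast_by_type`, the category ranking, and the string dtype names are all hypothetical.

```python
# Hypothetical sketch of a *type-based* "safe cast" check. This is an assumed
# model of the behavior implied by Thunder's error message, not Thunder's code.

# Assumed ordering of number categories, from narrowest to widest.
_CATEGORY_RANK = {bool: 0, int: 1, float: 2, complex: 3}

# Hypothetical mapping from dtype names to the number category they hold.
_DTYPE_CATEGORY = {
    "bool": bool,
    "int32": int,
    "uint32": int,
    "float32": float,
    "complex64": complex,
}


def can_safe_cast_by_type(value, dtype_name: str) -> bool:
    """Return True if the category of type(value) is no wider than the dtype's.

    Only the Python type of `value` is inspected; its magnitude is ignored.
    """
    value_rank = _CATEGORY_RANK[type(value)]
    dtype_rank = _CATEGORY_RANK[_DTYPE_CATEGORY[dtype_name]]
    return value_rank <= dtype_rank


# A float is rejected for a bool dtype regardless of its value.
print(can_safe_cast_by_type(0.5, "bool"))      # False
print(can_safe_cast_by_type(True, "float32"))  # True
```

Under this model `0.5` is also rejected for `int32`, whereas `torch.full` accepts it and truncates it to `0`.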

Expected behavior

Thunder should behave the same as PyTorch. For reference, this is how `torch.full` handles several `fill_value`/`dtype` combinations:

```python
>>> import sys
>>> import torch
>>> torch.full((2,2),0.5,dtype=torch.bool)
tensor([[True, True],
        [True, True]])
>>> torch.full((2,2),0.5,dtype=torch.int)
tensor([[0, 0],
        [0, 0]], dtype=torch.int32)
>>> torch.full((2,2),-1,dtype=torch.uint32)
tensor([[4294967295, 4294967295],
        [4294967295, 4294967295]], dtype=torch.uint32)
>>> torch.full((2,2),sys.maxsize,dtype=torch.uint32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: value cannot be converted to type uint32_t without overflow
>>> x=torch.full((1,2), 1j, dtype=torch.float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: value cannot be converted to type float without overflow
```
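The PyTorch results above are consistent with a value-based check: a value is accepted unless it falls outside the target dtype's range, negative integers may wrap for unsigned types, and complex values are rejected for real dtypes when their imaginary part is nonzero. Below is a hypothetical Python sketch of such a check, written only to reproduce these five examples. It is not a port of the c10 code; `overflows_by_value`, the string dtype names, and the rule that any real number converts to bool are assumptions.

```python
import math
import sys

# Hypothetical (lowest, max) ranges for the integer dtypes used in this issue.
_INT_RANGES = {
    "int32": (-(2**31), 2**31 - 1),
    "uint32": (0, 2**32 - 1),
}
_FLOAT32_MAX = 3.4028234663852886e38  # largest finite float32


def overflows_by_value(value, dtype_name: str) -> bool:
    """Return True if `value` should be rejected for `dtype_name` (value-based)."""
    if isinstance(value, complex):
        if dtype_name.startswith("complex"):
            return False
        if value.imag != 0:
            # Complex -> real is rejected when the imaginary part is nonzero.
            return True
        value = value.real

    if dtype_name == "bool":
        # Assumption: any real number converts to bool (0 -> False, else True).
        return False

    if dtype_name in _INT_RANGES:
        lowest, highest = _INT_RANGES[dtype_name]
        if isinstance(value, float):
            if math.isnan(value) or math.isinf(value):
                return True
            # Floats are truncated toward zero, so only the range matters.
            return value < lowest or value > highest
        if lowest == 0 and value < 0:
            # Unsigned target: negative ints wrap (two's complement) if |value| fits.
            return -value > highest
        return value < lowest or value > highest

    if dtype_name == "float32":
        if math.isnan(value) or math.isinf(value):
            return False  # NaN and +/-inf are representable in float32
        return abs(value) > _FLOAT32_MAX

    raise ValueError(f"unsupported dtype in this sketch: {dtype_name}")


# Usage: the five torch.full calls shown above.
for value, dtype in [(0.5, "bool"), (0.5, "int32"), (-1, "uint32"),
                     (sys.maxsize, "uint32"), (1j, "float32")]:
    print(value, dtype, "rejected" if overflows_by_value(value, dtype) else "accepted")
```

The first three cases are accepted and the last two rejected, matching the PyTorch outputs above; a check of this shape looks at the value, not just its Python type.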

Additional context

This is a follow-up to https://github.com/Lightning-AI/lightning-thunder/pull/949.

cc @apaz-cli

kiya00, Aug 12 '24 17:08