lightning-thunder
lightning-thunder copied to clipboard
thunder.jit does not work for autocast transform in some cases
🐛 Bug
While working on #198, thunder.jit fails for the following case from thunder/tests/test_autocast.py.
To Reproduce
import torch
import thunder
from thunder.core.transforms import autocast
def h(a, b, c):
return (a @ b) + c
x, y, z = (torch.randn((2, 2), device='cuda', dtype=torch.float32) for _ in range(3))
jfunc = thunder.jit(autocast(h, dtype=thunder.bfloat16))
jfunc(x,y,z)
Error message:
Traceback (most recent call last):
File "/wayan/lightning-thunder/thunder/../mytests/autocast.py", line 11, in <module>
jfunc(x,y,z)
File "/wayan/lightning-thunder/thunder/__init__.py", line 661, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/__init__.py", line 277, in cache_info_wrapper
res = fn(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/__init__.py", line 622, in get_computation_and_inputs
extraces = transform_for_execution(
File "/wayan/lightning-thunder/thunder/common.py", line 581, in transform_for_execution
extrace = executors.passes.transform_for_execution(dce_trace, executors_list)
File "/wayan/lightning-thunder/thunder/executors/passes.py", line 139, in transform_for_execution
extrace = _transform_for_operator_executor_execution(trace, executors_list)
File "/wayan/lightning-thunder/thunder/executors/passes.py", line 112, in _transform_for_operator_executor_execution
extrace = transforms.visitor_transform(trace, visit_)
File "/wayan/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
visit_type = visit(bsym)
File "/wayan/lightning-thunder/thunder/executors/passes.py", line 97, in visit_
result: None | bool = visit_helper_(bsym)
File "/wayan/lightning-thunder/thunder/executors/passes.py", line 80, in visit_helper_
out = op(*bsym.args, **bsym.kwargs)
File "/wayan/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
result = self.meta(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
result = self.meta(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
result = fn(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/torch/__init__.py", line 2519, in matmul
return prims.matmul(a, b)
File "/wayan/lightning-thunder/thunder/core/symbol.py", line 253, in __call__
result = self.meta(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
result = fn(*args, **kwargs)
File "/wayan/lightning-thunder/thunder/core/prims.py", line 3432, in matmul_meta
utils.check(
File "/wayan/lightning-thunder/thunder/core/baseutils.py", line 103, in check
raise exception_type(s())
RuntimeError: Expected a.dtype=float32 and b.dtype=bfloat16 to be the same
Here is some debugging information that might help
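After _apply_trace_proxy_rename is applied with GeneralJitCtx._proxy_swapmap, the computation trace contains tree = prims.convert_element_type(a, dtypes.bfloat16) # tree: "cuda:0 f32[2, 2]". Before the rename, the same output was t1: "cuda:0 bf16[2, 2]"; the swap map replaces t1 with the proxy tree, whose dtype is float32, so the following matmul sees a float32 and a bfloat16 operand and fails the dtype check. The pdb session below shows the trace before and after the rename: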
> /wayan/lightning-thunder/thunder/core/jit_ext.py(1566)thunder_general_jit()
-> computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
(Pdb) l
1561
1562 # Update prologue trace by renaming proxies which are passed from prologue to the computation trace
1563 prologue_trace = _apply_trace_proxy_rename(prologue_trace, restrict_proxy_swapmap(pro_to_comp_proxies))
1564
1565 # Update computation trace by renaming proxies which are in the ctx._proxy_swapmap
1566 -> computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
1567
1568 # Update epilogue trace by renaming proxies which are passed to the epilogue trace from prologue and computation traces
1569 if epilogue_trace:
1570 epilogue_trace = _apply_trace_proxy_rename(
1571 epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue"
(Pdb) p computation_trace
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2):
# t_0: "cuda:0 f32[2, 2]"
# t_1: "cuda:0 f32[2, 2]"
# t_2: "cuda:0 f32[2, 2]"
t1 = prims.convert_element_type(t_0, dtypes.bfloat16) # t1: "cuda:0 bf16[2, 2]"
t2 = prims.convert_element_type(t_1, dtypes.bfloat16) # t2: "cuda:0 bf16[2, 2]"
t3 = prims.matmul(t1, t2) # t3: "cuda:0 bf16[2, 2]"
t5 = ltorch.add(t3, t_2, alpha=None) # t5: "cuda:0 f32[2, 2]"
# t4 = prims.convert_element_type(t3, dtypes.float32) # t4: "cuda:0 f32[2, 2]"
# t5 = prims.add(t4, t_2) # t5: "cuda:0 f32[2, 2]"
return t5
(Pdb) tmp_map=list(ctx._proxy_swapmap.items())
(Pdb) tmp_map
[(t_0, a), (t_1, b), (t0, x), (t3, result), (t1, tree)]
(Pdb) tmp_map[4][0].proxy.dtype
float32
(Pdb) tmp_map[4][1].dtype
float32
(Pdb) n
> /wayan/lightning-thunder/thunder/core/jit_ext.py(1569)thunder_general_jit()
-> if epilogue_trace:
(Pdb) p computation_trace
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
# a: "cuda:0 f32[2, 2]"
# b: "cuda:0 f32[2, 2]"
# t_2: "cuda:0 f32[2, 2]"
tree = prims.convert_element_type(a, dtypes.bfloat16) # tree: "cuda:0 f32[2, 2]"
t2 = prims.convert_element_type(b, dtypes.bfloat16) # t2: "cuda:0 bf16[2, 2]"
result = prims.matmul(tree, t2) # result: "cuda:0 bf16[2, 2]"
t5 = ltorch.add(result, t_2, alpha=None) # t5: "cuda:0 f32[2, 2]"
# t4 = prims.convert_element_type(result, dtypes.float32) # t4: "cuda:0 f32[2, 2]"
# t5 = prims.add(t4, t_2) # t5: "cuda:0 f32[2, 2]"
return t5
(Pdb)
cc @crcrpar