
thunder.jit does not work for autocast transform in some cases

Open kiya00 opened this issue 1 year ago • 0 comments

🐛 Bug

While working on #198, thunder.jit was found to fail on this case from thunder/tests/test_autocast.py:

To Reproduce

import torch
import thunder
from thunder.core.transforms import autocast

def h(a, b, c):
    return (a @ b) + c

x, y, z = (torch.randn((2, 2), device='cuda', dtype=torch.float32) for _ in range(3))
jfunc = thunder.jit(autocast(h, dtype=thunder.bfloat16))
jfunc(x, y, z)

Error message:

Traceback (most recent call last):
  File "/wayan/lightning-thunder/thunder/../mytests/autocast.py", line 11, in <module>
    jfunc(x,y,z)
  File "/wayan/lightning-thunder/thunder/__init__.py", line 661, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/__init__.py", line 277, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/__init__.py", line 622, in get_computation_and_inputs
    extraces = transform_for_execution(
  File "/wayan/lightning-thunder/thunder/common.py", line 581, in transform_for_execution
    extrace = executors.passes.transform_for_execution(dce_trace, executors_list)
  File "/wayan/lightning-thunder/thunder/executors/passes.py", line 139, in transform_for_execution
    extrace = _transform_for_operator_executor_execution(trace, executors_list)
  File "/wayan/lightning-thunder/thunder/executors/passes.py", line 112, in _transform_for_operator_executor_execution
    extrace = transforms.visitor_transform(trace, visit_)
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
    visit_type = visit(bsym)
  File "/wayan/lightning-thunder/thunder/executors/passes.py", line 97, in visit_
    result: None | bool = visit_helper_(bsym)
  File "/wayan/lightning-thunder/thunder/executors/passes.py", line 80, in visit_helper_
    out = op(*bsym.args, **bsym.kwargs)
  File "/wayan/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
    result = self.meta(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
    result = self.meta(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
    result = fn(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/torch/__init__.py", line 2519, in matmul
    return prims.matmul(a, b)
  File "/wayan/lightning-thunder/thunder/core/symbol.py", line 253, in __call__
    result = self.meta(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
    result = fn(*args, **kwargs)
  File "/wayan/lightning-thunder/thunder/core/prims.py", line 3432, in matmul_meta
    utils.check(
  File "/wayan/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected a.dtype=float32 and b.dtype=bfloat16 to be the same

Here is some debugging information that might help. After _apply_trace_proxy_rename is run with GeneralJitCtx._proxy_swapmap, the computation trace contains tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]", i.e. the renamed output proxy tree keeps a float32 annotation even though convert_element_type casts to bfloat16, which is what later trips the dtype check in matmul_meta. The pdb session below shows the computation trace before and after the rename (see also the sketch after it):

> /wayan/lightning-thunder/thunder/core/jit_ext.py(1566)thunder_general_jit()
-> computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
(Pdb) l
1561
1562        # Update prologue trace by renaming proxies which are passed from prologue to the computation trace
1563        prologue_trace = _apply_trace_proxy_rename(prologue_trace, restrict_proxy_swapmap(pro_to_comp_proxies))
1564
1565        # Update computation trace by renaming proxies which are in the ctx._proxy_swapmap
1566 ->     computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
1567
1568        # Update epilogue trace by renaming proxies which are passed to the epilogue trace from prologue and computation traces
1569        if epilogue_trace:
1570            epilogue_trace = _apply_trace_proxy_rename(
1571                epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue"
(Pdb) p computation_trace
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2):
  # t_0: "cuda:0 f32[2, 2]"
  # t_1: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  t1 = prims.convert_element_type(t_0, dtypes.bfloat16)  # t1: "cuda:0 bf16[2, 2]"
  t2 = prims.convert_element_type(t_1, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  t3 = prims.matmul(t1, t2)  # t3: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(t3, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(t3, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5
(Pdb) tmp_map=list(ctx._proxy_swapmap.items())
(Pdb) tmp_map
[(t_0, a), (t_1, b), (t0, x), (t3, result), (t1, tree)]
(Pdb) tmp_map[4][0].proxy.dtype
float32
(Pdb) tmp_map[4][1].dtype
float32
(Pdb) n
> /wayan/lightning-thunder/thunder/core/jit_ext.py(1569)thunder_general_jit()
-> if epilogue_trace:
(Pdb) p computation_trace
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"
  t2 = prims.convert_element_type(b, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  result = prims.matmul(tree, t2)  # result: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(result, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(result, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5
(Pdb)
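
To make the failure mode concrete, here is a minimal, self-contained sketch of the suspected mechanism. FakeProxy, convert_element_type, and matmul_meta below are hypothetical stand-ins, not thunder's actual classes or meta functions; the point is only that swapping an output proxy for a replacement whose dtype was recorded before the cast reproduces exactly the dtype mismatch that prims.matmul_meta reports.

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeProxy:
    # hypothetical stand-in for a TensorProxy: just a name and a dtype string
    name: str
    dtype: str

def convert_element_type(a: FakeProxy, dtype: str) -> FakeProxy:
    # the cast records a new output proxy carrying the target dtype (bf16)
    return FakeProxy(name="t1", dtype=dtype)

def matmul_meta(a: FakeProxy, b: FakeProxy) -> FakeProxy:
    # mirrors the check that raises in the traceback above
    if a.dtype != b.dtype:
        raise RuntimeError(f"Expected a.dtype={a.dtype} and b.dtype={b.dtype} to be the same")
    return FakeProxy(name="t3", dtype=a.dtype)

# proxies as they appear in the original computation trace
a = FakeProxy(name="t_0", dtype="float32")
t1 = convert_element_type(a, "bfloat16")   # t1: bf16, correct
t2 = FakeProxy(name="t2", dtype="bfloat16")

# a swap map whose replacement proxy ("tree") was recorded with the
# pre-cast dtype, as seen in ctx._proxy_swapmap in the pdb output above
proxy_swapmap = {t1: FakeProxy(name="tree", dtype="float32")}

# renaming replaces t1 with "tree", carrying the stale float32 annotation
renamed_t1 = proxy_swapmap.get(t1, t1)

# re-running the meta function over the renamed operands now fails
# exactly like the reported error
matmul_meta(renamed_t1, t2)
# RuntimeError: Expected a.dtype=float32 and b.dtype=bfloat16 to be the same

If this reading is correct, the replacement proxy stored in ctx._proxy_swapmap (tree, recorded as float32 per the pdb output) would need to match the dtype actually produced by the bound symbol it replaces (bfloat16 from convert_element_type) for the downstream meta checks to pass.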

cc @crcrpar

kiya00 · Apr 26, 2024