
[Bug] [Relay] data types float64 and float32 do not match in BroadcastRel

Open: jikechao opened this issue 2 years ago · 1 comment

A PyTorch model that uses the hardswish operator crashes during conversion to Relay when its input dtype is float64.

Actual behavior

data types float64 and float32 do not match in BroadcastRel
data types float64 and float32 do not match in BroadcastRel
Traceback (most recent call last):
  File "19_crash_hardswish.py", line 19, in <module>
    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
  File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4970, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4243, in convert_operators
    self.record_output_type(relay_out)
  File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 238, in record_output_type
    self.infer_type_with_prelude(output)
  File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 174, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/workplace/software/tvm/tvm/python/tvm/ir/transform.py", line 160, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
tvm.error.DiagnosticError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::$_6>(tvm::transform::$_6, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::DiagnosticContext::Render()
  0: _ZN3tvm7runtime6detail
  File "/workplace/software/tvm/tvm/src/ir/diagnostic.cc", line 131
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
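For context, the message above is raised by the type relation Relay uses for broadcasting binary operators, which requires both operands to have the same dtype. NumPy, by contrast, silently promotes float64 + float32 to float64. The pure-Python sketch below only mimics that strict rule for illustration; `broadcast_rel_dtype` is a hypothetical helper, not a TVM API.

```python
import numpy as np


def broadcast_rel_dtype(lhs_dtype: str, rhs_dtype: str) -> str:
    """Hypothetical mimic of Relay's dtype check for broadcast ops (not TVM code).

    Operands must have exactly the same dtype; there is no implicit promotion.
    """
    if lhs_dtype != rhs_dtype:
        raise TypeError(
            f"data types {lhs_dtype} and {rhs_dtype} do not match in BroadcastRel"
        )
    return lhs_dtype


# NumPy promotes a float64 array plus a float32 scalar to float64.
x = np.zeros((2, 4, 4), dtype=np.float64)
c = np.float32(3.0)
print((x + c).dtype)  # float64

# The strict rule rejects the same pair of dtypes instead of promoting.
try:
    broadcast_rel_dtype(str(x.dtype), str(c.dtype))
except TypeError as err:
    print(err)  # data types float64 and float32 do not match in BroadcastRel
```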

Steps to reproduce

import torch
from torch.nn import Module
from tvm import relay

# float64 input triggers the crash; the same script converts fine with float32
input_data = torch.randn([2, 4, 4], dtype=torch.float64)

class hardswish(Module):
    def forward(self, *args):
        return torch.nn.functional.hardswish(args[0])

# hardswish has no parameters, so .float() does not change the float64 input
m = hardswish().float().eval()

torch_outputs = m(input_data)  # PyTorch itself runs hardswish on float64

trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([2, 4, 4]))]

# Fails with "data types float64 and float32 do not match in BroadcastRel"
mod, params = relay.frontend.from_pytorch(trace, input_shapes)

Environment

  • TVM: 0.13.dev0
  • PyTorch: 1.13.1+cu117

Triage

  • needs-triage
  • front:pytorch

Analysis

The crash depends on the input dtype: with a float32 input, TVM converts the model without error; with a float64 input, conversion fails with the error shown above.
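One plausible explanation, stated here as an assumption rather than something confirmed in this issue, is that the converted graph builds hardswish as x * relu6(x + 3) / 6 with float32 scalar constants, so a float64 input meets a float32 operand in two broadcast ops (the add and the divide), which would also fit the message appearing twice. The NumPy sketch below is a reference hardswish that creates its constants in the input's dtype so every operand matches; `hardsigmoid_ref` and `hardswish_ref` are hypothetical names, not TVM or PyTorch functions.

```python
import numpy as np


def hardsigmoid_ref(x: np.ndarray) -> np.ndarray:
    """Hypothetical reference: relu6(x + 3) / 6 with constants in x's dtype."""
    # Building the constants with x.dtype keeps every operand the same dtype.
    zero = np.asarray(0.0, dtype=x.dtype)
    three = np.asarray(3.0, dtype=x.dtype)
    six = np.asarray(6.0, dtype=x.dtype)
    return np.clip(x + three, zero, six) / six


def hardswish_ref(x: np.ndarray) -> np.ndarray:
    """Hypothetical reference: x * hardsigmoid(x), preserving x's dtype."""
    return x * hardsigmoid_ref(x)


x = np.array([-4.0, -3.0, 0.0, 1.0, 3.0, 4.0], dtype=np.float64)
y = hardswish_ref(x)
print(y.dtype)  # float64
```

The same idea, creating constants with the input's dtype instead of a fixed float32, is what a dtype-safe conversion of hardswish (and hardsigmoid) would need.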

cc @echuraev @Hzfengsy @shingjan

jikechao, Jun 01 '23 11:06

The hardsigmoid, functional.normalize, and Linear operators crash in the same way.

jikechao, Jul 03 '23 16:07