Model creation breaks on dynamic input
🐛 Bug
The following PR is being tested on TPU v3 to improve dynamic shape functionality for simple models.
I see the error below at y_pred = model(x_test) when x_test = torch.nonzero(x_test) is included in the code sequence; without that line, the code runs normally.
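For context, torch.nonzero returns one row of indices per non-zero element, so the number of rows in its output depends on the tensor's values, not just its shape. That is what makes the shape dynamic. Per the PyTorch documentation it also returns a LongTensor, so the values that reach the model after this line are integer indices. The pure-Python sketch below uses a hypothetical helper, `nonzero_2d`, to mimic that behavior on a nested list. It is an illustration, not the PyTorch implementation.

```python
from typing import List


def nonzero_2d(matrix: List[List[float]]) -> List[List[int]]:
    """Mimic torch.nonzero on a 2-D nested list (hypothetical helper).

    Returns one [row, col] integer index pair per non-zero element, so the
    number of output rows depends on the data, not only on the input shape.
    """
    return [
        [i, j]
        for i, row in enumerate(matrix)
        for j, value in enumerate(row)
        if value != 0
    ]


# Two inputs with the same 2x3 shape...
a = [[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]]
b = [[0.0, 0.0, 0.0], [0.0, 5.0, 0.0]]

# ...produce outputs with different numbers of rows.
print(len(nonzero_2d(a)), len(nonzero_2d(b)))  # 3 1
```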
https://symbolize.stripped_domain/r/?trace=7f400b6b400b,7f400b6b408f,7f400b49d23f&map=
*** SIGABRT received by PID 393548 (TID 393548) on cpu 25 from PID 393548; stack trace: ***
PC: @ 0x7f400b6b400b (unknown) raise
@ 0x7f3ea567f2d4 1120 (unknown)
@ 0x7f400b6b4090 919423072 (unknown)
@ 0x7f400b49d240 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f400b6b400b,7f3ea567f2d3,7f400b6b408f,7f400b49d23f&map=13dcea8db75a2ed88ce9356603629dcb:7f3e95b07000-7f3ea591c9e0
E1020 01:32:56.563561 393548 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
E1020 01:32:56.563569 393548 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
E1020 01:32:56.563576 393548 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1020 01:32:56.563580 393548 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
E1020 01:32:56.563585 393548 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E1020 01:32:56.563592 393548 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E1020 01:32:56.563595 393548 coredump_hook.cc:580] RAW: Discarding core.
E1020 01:32:56.819510 393548 process_state.cc:775] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
CC @Krovatkin @vanbasten23
Digging deeper, it turns out that the fully connected layer call in the forward() function leads to this failure:
def forward(self, x):
    hidden = self.fc1(x)
    #relu = self.relu(hidden)
    #output = self.fc2(relu)
    #output = self.sigmoid(output)
    return hidden
Commenting out everything in this function and simply returning x does not trigger the error above.
Can you run resnet50 without errors? For a crash like this, I think we need to use gdb.
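A lightweight complement to gdb, not something tried in this thread, is Python's standard-library faulthandler module. It installs handlers for SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL that dump the Python-level traceback when the process dies, which shows the last Python frame before the abort without stepping through pdb. Running the script with `python3 -X faulthandler` has the same effect. A minimal sketch:

```python
import faulthandler
import sys

# Install handlers for SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL that dump
# the Python traceback of every thread. sys.__stderr__ is the process's
# original stderr, which has a real file descriptor for the handler to use.
faulthandler.enable(file=sys.__stderr__, all_threads=True)

print(faulthandler.is_enabled())  # True

# ... build the model and run y_pred = model(x_test) after this point ...
```

This only reports the Python side. The native frames below the dispatcher still need gdb.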
It turns out torch.nn.Sigmoid and torch.nn.ReLU work fine (i.e., they don't cause a crash); it is specifically torch.nn.Linear that leads to this issue.
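One way to see why Linear stands out is to compare the operations. ReLU and Sigmoid are elementwise, so each output element depends only on the matching input element and the output shape always equals the input shape. Linear computes y = x Wᵀ + b, which contracts over in_features. That is a matrix multiply, and the gdb traces later in this thread show it reaching at::_ops::addmm through at::native::linear. The pure-Python sketch below uses hypothetical helpers (`relu`, `sigmoid`, `linear`). It only illustrates the math and does not model PyTorch/XLA dispatch.

```python
import math
from typing import List

Matrix = List[List[float]]


def relu(x: Matrix) -> Matrix:
    # Elementwise: each output element depends only on the matching input element.
    return [[max(0.0, v) for v in row] for row in x]


def sigmoid(x: Matrix) -> Matrix:
    # Also elementwise, so the output shape always equals the input shape.
    return [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in x]


def linear(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # y = x @ weight^T + bias, with weight shaped [out_features, in_features].
    # This contracts over in_features, i.e. it is a matrix multiply.
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


x = [[1.0, -2.0], [0.5, 3.0], [-1.0, 0.0]]  # 3 rows; the row count may vary
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # 3 out_features, 2 in_features
b = [0.0, 0.0, 1.0]

print(linear(x, w, b))
print(relu(x))
```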
Confirming that I can run the model successfully on both the TPU and CPU backends when x_test = torch.nonzero(x_test) is left out.
Below is the output of stepping through the torch.nn.Linear call in forward() after a pdb.set_trace(). The trace suggests the dispatcher call is the last Python-level call before the crash, so it seems I need to dig into the C++ stack.
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(339)__getattr__()
-> r = self.py_kernels.get(key, key)
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(340)__getattr__()
-> setattr(self, attr, r)
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()
-> return r
(Pdb) s
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey....Autograd: 119>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(306)__getattr__()
-> if len(attr) == 0 or not attr[0].isupper():
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey.Autograd: 118>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(306)__getattr__()
-> if len(attr) == 0 or not attr[0].isupper():
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey....aceOrView: 22>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) s
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(306)__getattr__()
-> if len(attr) == 0 or not attr[0].isupper():
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey.XLA: 47>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey.Autograd: 118>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey.XLA: 47>
-> return r
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(489)__getattr__()
-> def __getattr__(self, op_name):
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(520)__getattr__()-><OpOverloadPa...ten._to_cpu')>
-> return opoverloadpacket
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(391)__getattr__()
-> def __getattr__(self, key):
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(429)__getattr__()-><OpOverload(o...ad='default')>
-> return overload
(Pdb) s
--Call--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(305)__getattr__()
-> def __getattr__(self, attr):
(Pdb) r
--Return--
> /usr/local/lib/python3.8/dist-packages/torch/_ops.py(341)__getattr__()-><DispatchKey.XLA: 47>
-> return r
(Pdb) s
https://symbolize.stripped_domain/r/?trace=7fd607f1100b,7fd607f1108f,7fd607cfa23f&map=
*** SIGABRT received by PID 512418 (TID 512418) on cpu 28 from PID 512418; stack trace: ***
PC: @ 0x7fd607f1100b (unknown) raise
@ 0x7fd4a1edc2d4 1120 (unknown)
@ 0x7fd607f11090 196934768 (unknown)
@ 0x7fd607cfa240 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fd607f1100b,7fd4a1edc2d3,7fd607f1108f,7fd607cfa23f&map=13dcea8db75a2ed88ce9356603629dcb:7fd492364000-7fd4a21799e0
E1020 03:35:27.486869 512418 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
E1020 03:35:27.486877 512418 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
E1020 03:35:27.486885 512418 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1020 03:35:27.486888 512418 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
E1020 03:35:27.486894 512418 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E1020 03:35:27.486901 512418 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E1020 03:35:27.486904 512418 coredump_hook.cc:580] RAW: Discarding core.
E1020 03:35:27.792265 512418 process_state.cc:775] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
Debugging further with gdb:
It turns out the CPU fallback call in XLANativeFunctions::addmm causes this crash (a SIGABRT rather than a segfault). See the stack trace below.
In this scenario, I assume mat1 is dynamic. @Krovatkin, is there a reason not to call at::native::call_fallback_fn on dynamic inputs?
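For readers less familiar with the fallback path, a CPU fallback conceptually copies each XLA input to CPU, runs the regular CPU kernel, and copies the result back to the original device. In PyTorch/XLA, moving a tensor to CPU forces its pending lazy computation to execute, so every input must end up with concrete sizes and data before the CPU kernel runs. The toy sketch below uses hypothetical names (`FakeTensor`, `to_device`, `cpu_fallback`, `cpu_add`) and only illustrates that round trip. It is not the xla_cpu_fallback implementation.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class FakeTensor:
    # Hypothetical stand-in for a tensor: flat data plus the device it lives on.
    data: List[float]
    device: str


def to_device(t: FakeTensor, device: str) -> FakeTensor:
    # In PyTorch/XLA a copy to CPU would first execute the pending lazy graph;
    # here we just copy the data and relabel the device.
    return FakeTensor(list(t.data), device)


def cpu_fallback(op: Callable[..., FakeTensor], *args: FakeTensor) -> FakeTensor:
    """Hypothetical fallback: copy inputs to CPU, run op there, copy result back."""
    original_device = args[0].device
    cpu_args = [to_device(a, "cpu") for a in args]
    cpu_result = op(*cpu_args)
    return to_device(cpu_result, original_device)


def cpu_add(a: FakeTensor, b: FakeTensor) -> FakeTensor:
    # A CPU-only kernel that rejects inputs living on any other device.
    if a.device != "cpu" or b.device != "cpu":
        raise RuntimeError("cpu_add expects CPU tensors")
    return FakeTensor([x + y for x, y in zip(a.data, b.data)], "cpu")


x = FakeTensor([1.0, 2.0], "xla:0")
y = FakeTensor([3.0, 4.0], "xla:0")
out = cpu_fallback(cpu_add, x, y)
print(out)
```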
Thread 1 "python3" received signal SIGABRT, Aborted.
__GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
50 ../sysdeps/unix/sysv/linux/raise.c: No such file or directory.
(gdb) bt
#0 __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1 0x00007ffff7bf7859 in __GI_abort () at abort.c:79
#2 0x00007ffff79f1212 in ?? () from /lib/x86_64-linux-gnu/libunwind.so.8
#3 0x00007fffe9b447f5 in at::native::_call_fallback_fn<&torch_xla::xla_cpu_fallback, at::_ops::addmm, false, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&)>::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) (args=..., args=..., args=..., args=..., args=...) at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/native/CPUFallback.h:36
#4 0x000000004987f6e0 in ?? ()
#5 0x00007fffffffc678 in ?? ()
#6 0x000000004987f6f0 in ?? ()
#7 0x00007fffffffc6a0 in ?? ()
#8 0x00007fffe9b06176 in torch_xla::XLANativeFunctions::addmm (self=..., mat1=..., mat2=..., beta=..., alpha=...) at /pytorch/xla/torch_xla/csrc/aten_xla_type.cpp:621
#9 0x00007fffe9bc8c6b in at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm (self=..., mat1=..., mat2=..., beta=..., alpha=...) at /pytorch/xla/torch_xla/csrc/generated/RegisterXLA.cpp:514
#10 c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >::operator()(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) (this=<optimized out>, args=...,
args=..., args=..., args=..., args=...) at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13
#11 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) (functor=<optimized out>, args=..., args=..., args=..., args=..., args=...)
at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:461
#12 c10::impl::call_functor_with_args_from_stack_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, false, 0ul, 1ul, 2ul, 3ul, 4ul, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&>(c10::OperatorKernel*, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<std::vector> >*, std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&>*) (functor=<optimized out>, dispatchKeySet=..., stack=0x7fffffffc820) at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:496
#13 c10::impl::call_functor_with_args_from_stack<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, false>(c10::OperatorKernel*, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<std::vector> >*) (functor=<optimized out>, dispatchKeySet=..., stack=0x7fffffffc820) at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:511
#14 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) (functor=<optimized out>, dispatchKeySet=..., stack=0x7fffffffc820) at /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:581
#15 0x00007ffec8a80b6a in c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#16 0x00007ffec8a73403 in (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const ()
from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#17 0x00007ffebeec733f in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#18 0x00007ffebedcf0fb in at::_ops::addmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#19 0x00007ffec0c09fb4 in torch::autograd::VariableType::(anonymous namespace)::addmm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) ()
from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#20 0x00007ffec0c08d84 in c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &torch::autograd::VariableType::(anonymous namespace)::addmm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#21 0x00007ffec8a80b6a in c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#22 0x00007ffec8a73403 in (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const ()
from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#23 0x00007ffebeec733f in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#24 0x00007ffebedcee5e in at::_ops::addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#25 0x00007ffebe5bfa58 in at::native::linear(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#26 0x00007ffebf955c20 in c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__linear>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#27 0x00007ffec8a80b6a in c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#28 0x00007ffec8a73403 in (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const ()
from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#29 0x00007ffebee9999f in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#30 0x00007ffebeda0f98 in at::_ops::linear::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#31 0x00007ffec88fb765 in torch::autograd::THPVariable_linear(_object*, _object*, _object*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#32 0x00000000005f6929 in PyCFunction_Call ()
#33 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#34 0x0000000000571164 in _PyEval_EvalFrameDefault ()
#35 0x00000000005f6cd6 in _PyFunction_Vectorcall ()
#36 0x000000000050bc2c in ?? ()
#37 0x00000000005f6082 in PyObject_Call ()
#38 0x000000000056d2d5 in _PyEval_EvalFrameDefault ()
#39 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#40 0x00000000005f6eb3 in _PyFunction_Vectorcall ()
#41 0x000000000059d81e in ?? ()
#42 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#43 0x0000000000571164 in _PyEval_EvalFrameDefault ()
#44 0x00000000005f6cd6 in _PyFunction_Vectorcall ()
#45 0x000000000050bc2c in ?? ()
#46 0x00000000005f6082 in PyObject_Call ()
#47 0x000000000056d2d5 in _PyEval_EvalFrameDefault ()
#48 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#49 0x00000000005f6eb3 in _PyFunction_Vectorcall ()
#50 0x000000000059d81e in ?? ()
#51 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#52 0x0000000000570d55 in _PyEval_EvalFrameDefault ()
#53 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#54 0x00000000006902a7 in PyEval_EvalCode ()
#55 0x000000000067f951 in ?? ()
#56 0x000000000067f9cf in ?? ()
#57 0x000000000067fa71 in ?? ()
#58 0x0000000000681b97 in PyRun_SimpleFileExFlags ()
#59 0x00000000006b9d32 in Py_RunMain ()
#60 0x00000000006ba0bd in Py_BytesMain ()
#61 0x00007ffff7bf9083 in __libc_start_main (main=0x4efd60 <main>, argc=2, argv=0x7fffffffe3e8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffe3d8) at ../csu/libc-start.c:308
#62 0x00000000005fc5fe in _start ()
It looks like the dynamic tensor causes a crash in _call_fallback_fn, although SymInt appears to be supported in this method. I wonder whether the root cause is on the PyTorch side or the PyTorch/XLA side. @Krovatkin @wconstab, let me know what you think.
To be clear, I see this failure only when running the model on the PyTorch/XLA:TPU device, not on the PyTorch/XLA:CPU device.
So it must be falling back at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L618.
- I am not sure why it only crashes on TPU; the fallback here is backend-agnostic.
- I am not sure we have to fall back here. I feel we could still lower it to something other than dot.
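On the second bullet, torch.addmm(input, mat1, mat2, beta=1, alpha=1) is documented as computing beta * input + alpha * (mat1 @ mat2). The scaling part therefore needs nothing beyond an elementwise multiply and add around the matrix product. The pure-Python sketch below shows that decomposition. `matmul` and `addmm_decomposed` are hypothetical helpers, and the bias is assumed to be a 1-D vector broadcast over rows (the nn.Linear case). The sketch says nothing about how PyTorch/XLA actually lowers addmm or which primitive should compute the product itself.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain matrix multiply: [n, k] @ [k, m] -> [n, m].
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def addmm_decomposed(bias: List[float], mat1: Matrix, mat2: Matrix,
                     beta: float = 1.0, alpha: float = 1.0) -> Matrix:
    """beta * bias + alpha * (mat1 @ mat2), with a 1-D bias broadcast over rows."""
    product = matmul(mat1, mat2)
    out = []
    for row in product:
        out.append([
            # As documented for torch.addmm, beta == 0 ignores the bias entirely,
            # so nan/inf in the bias are not propagated.
            (0.0 if beta == 0 else beta * b) + alpha * p
            for b, p in zip(bias, row)
        ])
    return out


mat1 = [[1.0, 2.0], [3.0, 4.0]]
mat2 = [[1.0, 0.0], [0.0, 1.0]]  # identity
bias = [10.0, 20.0]

print(addmm_decomposed(bias, mat1, mat2))                      # default scalars
print(addmm_decomposed(bias, mat1, mat2, beta=0.5, alpha=2.0))
```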
Nothing in the stack trace really jumps out at me. I think you'll have to debug the crash to find out exactly which variable, tensor, etc. is causing it; once you have that information, maybe we can trace it backwards through the stack and come up with a theory of where it went wrong.
It turns out that even after the changes made in this PR, I still hit the CPU fallback code path. This suggests the main cause of the CPU fallback is the following condition: beta.to<double>() != 1 || alpha.to<double>() != 1. The strange thing is that the torch.nn.Linear implementation shows alpha and beta are in fact left at their default value of 1 (see the addmm calls in this method).
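To make the puzzle concrete, for a 2-D input with a bias, nn.Linear (through at::native::linear, as the gdb traces show) calls addmm(bias, input, weight.t()) without passing beta or alpha. Both therefore keep their default of 1, and the quoted condition on its own should evaluate to false. The sketch below is a hypothetical Python mirror of that condition. `should_fall_back` is an illustration, not PyTorch/XLA code.

```python
def should_fall_back(beta: float = 1.0, alpha: float = 1.0) -> bool:
    """Hypothetical Python mirror of the quoted C++ condition:
    beta.to<double>() != 1 || alpha.to<double>() != 1
    """
    return float(beta) != 1.0 or float(alpha) != 1.0


# nn.Linear (2-D input with bias) calls addmm(bias, input, weight.t())
# without passing beta/alpha, i.e. it relies on the defaults of 1.
print(should_fall_back())                 # False: Linear's case
print(should_fall_back(beta=0.5))         # True
print(should_fall_back(beta=1, alpha=1))  # False: int 1 converts to 1.0
```

If the fallback is still taken with these scalars, something other than this part of the condition is likely triggering it.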
__GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
50 ../sysdeps/unix/sysv/linux/raise.c: No such file or directory.
(gdb) bt
#0 __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1 0x00007ffff7bf6859 in __GI_abort () at abort.c:79
#2 0x00007ffff7de41d2 in ?? () from /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
#3 0x00007ffff7de5ca9 in ?? () from /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
#4 0x00007ffebea6f51d in at::_ops::addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#5 0x00007fffffffce60 in ?? ()
#6 0x00000000054c3a40 in ?? ()
#7 0x00007fffffffcfd0 in ?? ()
#8 0x00000000054c3a50 in ?? ()
#9 0x00007ffebe25b9f8 in at::native::linear(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#10 0x00007ffebf60ecb0 in c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__linear>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#11 0x00007ffec8a7d3aa in c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#12 0x00007ffec8a6fcd3 in (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const ()
from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#13 0x00007ffebeb3a2df in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#14 0x00007ffebea41548 in at::_ops::linear::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#15 0x00007ffec88f6823 in torch::autograd::THPVariable_linear(_object*, _object*, _object*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#16 0x00000000005f6929 in PyCFunction_Call ()
#17 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#18 0x0000000000571164 in _PyEval_EvalFrameDefault ()
#19 0x00000000005f6cd6 in _PyFunction_Vectorcall ()
#20 0x000000000050bc2c in ?? ()
#21 0x00000000005f6082 in PyObject_Call ()
#22 0x000000000056d2d5 in _PyEval_EvalFrameDefault ()
#23 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#24 0x00000000005f6eb3 in _PyFunction_Vectorcall ()
#25 0x000000000059d81e in ?? ()
#26 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#27 0x0000000000571164 in _PyEval_EvalFrameDefault ()
#28 0x00000000005f6cd6 in _PyFunction_Vectorcall ()
#29 0x000000000050bc2c in ?? ()
#30 0x00000000005f6082 in PyObject_Call ()
#31 0x000000000056d2d5 in _PyEval_EvalFrameDefault ()
#32 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#33 0x00000000005f6eb3 in _PyFunction_Vectorcall ()
#34 0x000000000059d81e in ?? ()
#35 0x00000000005f74f6 in _PyObject_MakeTpCall ()
#36 0x0000000000570d55 in _PyEval_EvalFrameDefault ()
#37 0x0000000000569dba in _PyEval_EvalCodeWithName ()
#38 0x00000000006902a7 in PyEval_EvalCode ()
#39 0x000000000067f951 in ?? ()
#40 0x000000000067f9cf in ?? ()
#41 0x000000000067fa71 in ?? ()
#42 0x0000000000681b97 in PyRun_SimpleFileExFlags ()
#43 0x00000000006b9d32 in Py_RunMain ()
#44 0x00000000006ba0bd in Py_BytesMain ()
#45 0x00007ffff7bf8083 in __libc_start_main (main=0x4efd60 <main>, argc=2, argv=0x7fffffffe3b8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffe3a8) at ../csu/libc-start.c:308
#46 0x00000000005fc5fe in _start ()
Hi Milad (@miladm), is it OK to assign this issue to you?