[torchbench] `vision_maskrcnn` failing on inference with dynamo after `bfloat16` conversion.
🐛 Bug
After converting the `vision_maskrcnn` model to `bfloat16`, running inference with dynamo (command below) fails with the following error:
python xla/benchmarks/experiment_runner.py \
--suite-name torchbench --accelerator cuda \
--xla PJRT --dynamo openxla --test eval \
--no-resume --print-subprocess \
-k vision_maskrcnn
F0000 00:00:1708093149.740541 20316 debug_macros.h:20] Non-OK-status: status.status() status: INVALID_ARGUMENT: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
torch_xla::IndexGet::IndexGet(torch::lazy::Value const&, torch::lazy::Value const&, long)
torch_xla::IndexByTensors(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, absl::lts_20230802::Span<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const>, long)
torch_xla::tensor_methods::index(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, absl::lts_20230802::Span<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const>, long)
torch_xla::XLANativeFunctions::index(at::Tensor const&, c10::List<std::optional<at::Tensor> > const&)
at::_ops::index_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::List<std::optional<at::Tensor> > const&)
at::_ops::index_Tensor::call(at::Tensor const&, c10::List<std::optional<at::Tensor> > const&)
torch::autograd::THPVariable_getitem(_object*, _object*)
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyObject_FastCallDict
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyObject_FastCallDict
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCodeEx
PyEval_EvalCode
PyRun_SimpleFileExFlags
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
*** Check failure stack trace: ***
@ 0x7f33961176a9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f338c1ffde9 ConsumeValue<>()
@ 0x7f338c1ffe3e torch_xla::ShapeHelper::ShapeOfXlaOp()
@ 0x7f338c173e5c torch_xla::InferOutputShape()
@ 0x7f338c16a18d torch_xla::(anonymous namespace)::NodeOutputShape()
@ 0x7f338c16a213 std::_Function_handler<>::_M_invoke()
@ 0x7f338c1f3b96 torch_xla::XlaNode::GetOpShape()
@ 0x7f338c1f4469 torch_xla::XlaNode::XlaNode()
@ 0x7f338c16a38d torch_xla::IndexGet::IndexGet()
@ 0x7f338c171fc8 torch_xla::IndexByTensors()
@ 0x7f338beea872 torch_xla::tensor_methods::index()
@ 0x7f338be29be4 torch_xla::XLANativeFunctions::index()
@ 0x7f338c0edc1f c10::impl::make_boxed_from_unboxed_functor<>::call()
@ 0x7f348a089d9b (anonymous namespace)::functionalizeFallback()
@ 0x7f348ade0e79 at::_ops::index_Tensor::redispatch()
@ 0x55bd4ff82f10 (unknown)
Environment
- Reproducible on XLA backend [CPU/TPU]: CUDA
- torch_xla version: 20692cb04256ca24b1dd2c7d00ffba0bb0cb69ac
cc @miladm @JackCaoG