Dynamic batch dimension export tutorial (following https://openxla.org/stablehlo/tutorials/pytorch-export) aborts with a fatal check failure (torch-xla==2.4.0)
🐛 Bug
I am trying to follow https://openxla.org/stablehlo/tutorials/pytorch-export verbatim.
I have taken the official python:3.10 Docker image, and inside it I have installed:
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.0
torchvision==0.19.0
torchaudio==2.4.0
torch-xla==2.4.0
tensorflow-cpu==2.16.2
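To confirm that the container really has the versions pinned above, a small check along these lines may help. It is only a sketch: the helper name `installed_versions` is hypothetical, and it relies solely on the standard-library `importlib.metadata` module.

```python
from importlib import metadata

# Package names taken from the pinned list above.
PACKAGES = ["torch", "torchvision", "torchaudio", "torch-xla", "tensorflow-cpu"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Pasting the printed lines into the report makes it easy to spot a mismatch between the requested and the resolved versions.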
This call from the tutorial (under Export with dynamic batch dimension):
dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)
Aborts the process with a fatal check failure (no Python exception is raised):
F0000 00:00:1727773477.564067 2760 debug_macros.h:20] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Non-broadcast dimensions must not be dynamic.
*** Begin stack trace ***
tsl::CurrentStackTrace()
xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
torch_xla::BuildMaxPoolNd(xla::XlaOp, long, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, bool)
torch_xla::MaxPoolNd::Lower(torch_xla::LoweringContext*) const
torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
*** Check failure stack trace: ***
@ 0x7effc17b15f9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7effb9ce60e4 ConsumeValue<>()
@ 0x7effb9ce614e torch_xla::ShapeHelper::ShapeOfXlaOp()
@ 0x7effb997bc03 torch_xla::(anonymous namespace)::ComputeMaxPoolIndices()
@ 0x7effb997cff1 torch_xla::BuildMaxPoolNd()
@ 0x7effb9c5fd76 torch_xla::MaxPoolNd::Lower()
@ 0x7effb9cdfcfd torch_xla::LoweringContext::LowerNode()
@ 0x7effb9ce0683 torch_xla::LoweringContext::GetOutputOp()
@ 0x7effb9ce0981 torch_xla::LoweringContext::AddResult()
@ 0x7effb9975474 torch_xla::DumpUtil::ToHlo()
@ 0x7effb9ae31f6 torch_xla::XLAGraphExecutor::DumpHloComputation()
@ 0x7effb986a083 torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#67}::operator()()
@ 0x7effb988519f pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7effb986d1be pybind11::cpp_function::dispatcher()
@ 0x7f020cefc4fd cfunction_call
Aborted (core dumped)
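Because the failure is a fatal check that ends with `Aborted (core dumped)`, a `try`/`except` around the call cannot catch it. Below is a minimal sketch for reproducing it in a child process so the parent interpreter survives. The script body is my reconstruction of the tutorial's dynamic-batch section and may differ from it in details. In particular, `weights=None` (random weights, no download) is an assumption, and `run_repro` is a hypothetical helper name.

```python
import subprocess
import sys

# Reconstructed repro; run in a child process because the failure aborts
# the interpreter instead of raising a Python exception.
REPRO_SCRIPT = r'''
import torch
import torchvision
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo

# Assumption: random weights instead of the pretrained ones used in the tutorial.
model = torchvision.models.resnet18(weights=None).eval()
sample_input = (torch.randn(4, 3, 224, 224),)

# Mark dimension 0 (the batch dimension) of the single input as dynamic.
batch = Dim("batch", max=15)
dynamic_export = export(model, sample_input, dynamic_shapes=({0: batch},))

# The call that aborts in this report.
dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)
print("export succeeded")
'''


def run_repro(timeout=600):
    """Run the repro in a subprocess; return (return code, tail of stderr)."""
    proc = subprocess.run(
        [sys.executable, "-c", REPRO_SCRIPT],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stderr[-2000:]
```

Calling `run_repro()` inside the container returns the child's exit status. On POSIX, a negative return code `-N` means the child was killed by signal `N`, so a value of `-6` (SIGABRT on Linux) would be consistent with the abort shown above.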