Assert on empty PJRT buffers
A simple change to avoid a segmentation fault when placeholder tensors are involved and we attempt to dereference the device from the buffer. It fixes the segfault reported in https://github.com/pytorch/xla/issues/9049. With this change, the failure surfaces as a runtime error with a stack trace instead of a crash:
ERROR:root:Caught an exception when exiting the process. Exception:
Traceback (most recent call last):
File "/ansible/pytorch/xla/torch_xla/__init__.py", line 209, in _prepare_to_exit
_XLAC._prepare_to_exit()
RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:741 : Check failed: pjrt_data->buffer != nullptr
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
torch::lazy::MultiWait::Complete(std::function<void ()> const&)
std::function<void ()>::operator()() const
tsl::thread::EigenEnvironment::ExecuteTask(tsl::thread::EigenEnvironment::Task const&)
Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}::operator()() const
void std::__invoke_impl<void, Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}&>(std::__invoke_other, Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}&)
std::enable_if<is_invocable_r_v<void, Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}&>, std::enable_if>::type std::__invoke_r<void, Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}&>(void&&, (Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}&)...)
std::_Function_handler<void (), Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tsl::thread::EigenEnvironment)::{lambda()#1}>::_M_invoke(std::_Any_data const&)
std::function<void ()>::operator()() const
tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}::operator()() const
void std::__invoke_impl<void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(std::__invoke_other, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&)
std::__invoke_result<tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>::type std::__invoke<tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(std::__invoke_result&&, (tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&)...)
std::invoke_result<tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>::type std::invoke<tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(std::invoke_result&&, (tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&)...)
void absl::lts_20230802::internal_any_invocable::InvokeR<void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&, , void>(tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&)
void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
absl::lts_20230802::internal_any_invocable::Impl<void ()>::operator()()
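For reference, a rough Python sketch of the kind of scenario from the linked issue that now produces the error above instead of a segfault. The create_placeholder_tensor helper and its signature are assumptions here; the real reproduction steps live in the issue.

```python
# Rough repro sketch (not verbatim from the issue). Assumes the
# create_placeholder_tensor helper from torch_xla.core.xla_builder; the exact
# helper name and signature may differ.
import torch
import torch_xla
import torch_xla.core.xla_builder as xb

# A placeholder tensor has a shape and dtype but no PJRT buffer behind it.
p = xb.create_placeholder_tensor(torch.Size([3, 3]), torch.float32)

# Working with the placeholder's graph eventually reaches ExecuteComputation
# with a null buffer; with this change that path now raises the RuntimeError
# shown above instead of segfaulting.
print(torch_xla._XLAC._get_xla_tensors_hlo([p]))
```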
Is it possible to add a test for this?
Looking at #9049 quickly, I didn't follow how calling _get_xla_tensors_hlo leads to an ExecuteComputation call on the PjRtClient.
Good question; that will need to be explained as part of the issue above. It was unclear to us as well when we came across it.
I am happy to add one. Is the request basically to add a test checking that _get_xla_tensors_hlo with placeholders hits the assert? Usually, checks like this don't require a standalone test, since these invariants exist only to capture those particular undesired cases. In this case, the check gives us a stack trace and an error message instead of a segfault.
> Is the request basically to add a test checking that _get_xla_tensors_hlo with placeholders hits the assert?
Yeah. That would be good enough!
> Usually, checks like this don't require a standalone test, since these invariants exist only to capture those particular undesired cases.
While I agree with you, I think this test is important because it gives us two guarantees (a sketch of such a test follows below):
- The assertion actually triggers the behavior we expect: an error instead of a segfault
- This behavior is preserved over time
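Something along these lines could work as the regression test. It is only a sketch: it reuses the assumed create_placeholder_tensor helper from the repro above and asserts that the call now raises a RuntimeError rather than crashing the process.

```python
# Hedged sketch of a possible regression test; the helper names follow the
# repro sketch above and are assumptions, not the final test-suite API.
import unittest

import torch
import torch_xla
import torch_xla.core.xla_builder as xb


class PlaceholderAssertTest(unittest.TestCase):

  def test_placeholder_raises_instead_of_segfault(self):
    # Placeholder tensor backed by a null PJRT buffer (assumed helper).
    p = xb.create_placeholder_tensor(torch.Size([3, 3]), torch.float32)
    # The new check should surface as a RuntimeError somewhere along the
    # _get_xla_tensors_hlo path instead of crashing the process.
    with self.assertRaises(RuntimeError):
      torch_xla._XLAC._get_xla_tensors_hlo([p])


if __name__ == '__main__':
  unittest.main()
```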