[torchbench] `timm_nfnet` training fails to run on AMP precision.
🐛 Bug
After starting to run training benchmarks on AMP (#6518), timm_nfnet
fails with the following error:
2024-02-27 05:47:24.611656: F ./torch_xla/csrc/runtime/debug_macros.h:20] Non-OK-status: status.status() status: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %multiply.67 = f32[128,3072,6,6]{3,2,1,0} multiply(f16[128,3072,6,6]{3,2,1,0} %multiply.44, f32[128,3072,6,6]{3,2,1,0} %add.66), but mixed precision is disallowed.
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230802::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call
torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&)
torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&)
torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
clone
*** End stack trace ***
Affected Configurations
- Dynamo Training
Environment
- Reproducible on XLA backend [CPU/TPU]: CUDA
- torch_xla version: 4327d24a3a9443b120ba04079a92cb5773e738cc
Additional Context
This error looks somewhat similar to the one in hf_GPT2 (#6521). The operands of a single HLO operation have different floating-point types (here `f16` and `f32` in `%multiply.67`), which suggests a possible problem with AMP.
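To make the mismatch easier to see, here is a minimal sketch that extracts the operand types from the HLO instruction quoted in the error above. It uses only the Python standard library. The helper names `operand_dtypes` and `has_mixed_precision` and the regular expression are assumptions made for illustration; they are not part of torch_xla or XLA, and the pattern is only meant to cover instructions shaped like this one.

```python
import re

# HLO instruction copied from the error message above.
HLO_LINE = (
    "%multiply.67 = f32[128,3072,6,6]{3,2,1,0} "
    "multiply(f16[128,3072,6,6]{3,2,1,0} %multiply.44, "
    "f32[128,3072,6,6]{3,2,1,0} %add.66)"
)

# Hypothetical pattern: matches one operand such as
# "f16[128,3072,6,6]{3,2,1,0} %multiply.44" (dtype, shape, optional layout, name).
OPERAND_RE = re.compile(
    r"(?P<dtype>[a-z]+\d+)\[[\d,]*\](?:\{[\d,]*\})?\s+%(?P<name>[\w.\-]+)"
)


def operand_dtypes(hlo_line):
    """Return {operand_name: dtype} for the operands inside the call parentheses."""
    args = hlo_line[hlo_line.index("(") + 1 : hlo_line.rindex(")")]
    return {m.group("name"): m.group("dtype") for m in OPERAND_RE.finditer(args)}


def has_mixed_precision(hlo_line):
    """True if the operands of the instruction do not all share one dtype."""
    return len(set(operand_dtypes(hlo_line).values())) > 1


if __name__ == "__main__":
    for name, dtype in operand_dtypes(HLO_LINE).items():
        print(f"%{name}: {dtype}")
    print("mixed precision:", has_mixed_precision(HLO_LINE))
```

For the instruction in this report, the sketch reports `%multiply.44` as `f16` and `%add.66` as `f32`, which is the combination the compiler rejects.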
cc @miladm @JackCaoG