[torchbench] `hf_GPT2` (large, too) fails to run on `bfloat16` dtype.

Open · ysiraichi opened this issue 1 year ago · 1 comment

🐛 Bug

After converting the hf_GPT2 model (and its large variant, hf_GPT2_large) to bfloat16, running it with the command below fails with the following error:

python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench --accelerator cuda \
    --xla PJRT --dynamo None --test eval \
    --no-resume --print-subprocess \
    -k hf_GPT2

2024-02-10 02:55:58.205631: F ./torch_xla/csrc/runtime/debug_macros.h:20] Non-OK-status: status.status() status: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %add.281 = f32[16384,768]{1,0} add(f32[16384,768]{1,0} %dot.276, bf16[16384,768]{1,0} %broadcast.280), but mixed precision is disallowed.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230802::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
        torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
        torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, bool)




        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault


        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCodeEx
        PyEval_EvalCode



        PyRun_SimpleFileExFlags
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
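
The key information in the log above is the HLO instruction quoted in the error: an `f32` `dot` result is added to a `bf16` `broadcast`. As a rough triage aid, the sketch below pulls the operand names and dtypes out of such a message using only the Python standard library. The function name `parse_mixed_precision_error` is hypothetical (it is not part of torch_xla), and the regular expressions are assumptions based solely on the single message shown in this issue, so other XLA messages may not match.

```python
import re

# Error text copied from the log in this issue (prefix trimmed for brevity).
ERROR_LINE = (
    "Seen floating point types of different precisions in "
    "%add.281 = f32[16384,768]{1,0} add(f32[16384,768]{1,0} %dot.276, "
    "bf16[16384,768]{1,0} %broadcast.280), but mixed precision is disallowed."
)

# One typed HLO operand, e.g. "bf16[16384,768]{1,0} %broadcast.280".
# Assumption: dtype is lowercase letters followed by digits (f32, bf16, ...).
OPERAND_RE = re.compile(
    r"(?P<dtype>[a-z]+\d+)\[[\d,]*\](?:\{[\d,]*\})?\s+%(?P<name>[\w.\-]+)"
)

# The instruction itself: "in %<name> = <result type> <opcode>(<operands>)".
INSTRUCTION_RE = re.compile(
    r"in %(?P<inst>[\w.\-]+) = .*? \w+\((?P<args>.*?)\)"
)


def parse_mixed_precision_error(message):
    """Return (instruction, [(operand, dtype), ...]), or None if no match."""
    match = INSTRUCTION_RE.search(message)
    if match is None:
        return None
    operands = [
        (op.group("name"), op.group("dtype"))
        for op in OPERAND_RE.finditer(match.group("args"))
    ]
    return match.group("inst"), operands


inst, operands = parse_mixed_precision_error(ERROR_LINE)
print(inst + ": " + ", ".join(f"{name} is {dtype}" for name, dtype in operands))
# prints "add.281: dot.276 is f32, broadcast.280 is bf16"
```

Listing the operands this way only shows which values disagree in precision. It does not identify which part of the model left the `dot` output in `f32`.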

To Reproduce

  • Run on: #6518

Affected Configurations

  • Non-Dynamo Inference
  • Dynamo Inference
  • Dynamo Training

Environment

  • Reproducible on XLA backend [CPU/TPU]: CUDA
  • torch_xla version: 408b3766bcff0a85bd486f232d972f7b1b3663cc

cc @miladm

ysiraichi, Feb 12 '24 19:02

Both hf_GPT2 and hf_GPT2_large seem to be failing with this same error in the following cases:

  • Non-dynamo: inference
  • Dynamo: inference + training

ysiraichi, Feb 16 '24 14:02