Lowering aten::where encounters mismatched types and crashes
🐛 Bug
Hi PyTorch/XLA developers,
I am trying to dump a Hugging Face OPT model to StableHLO with torch_xla. The model is exported by torch.export successfully, but the process crashes when lowering the exported program to StableHLO on the PyTorch/XLA side. It looks like a bug in PyTorch/XLA's lowering of the aten::where op.
Any help would be appreciated.
To Reproduce
Steps to reproduce the behavior:
- install the Hugging Face transformers lib
pip install transformers==4.37.2
- run the following script
from transformers import OPTForCausalLM
import torch
from torch_xla import stablehlo
model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b")
model.eval()
input = torch.randint(0, 100, size=(1, 108))
exported_fn = torch.export.export(model, args=(input,))
# works fine until here
# print(exported_fn)
options = stablehlo.StableHLOExportOptions()
# options.override_tracing_arguments = (m_args,)
shlo = stablehlo.exported_program_to_stablehlo(exported_fn, options)
method = shlo._name_to_stablehlo["forward"]
ir_str = method.text
print(ir_str)
- see the output
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707223682.454894 2471766 cpu_client.cc:370] TfrtCpuClient created.
Traceback (most recent call last):
File "test.py", line 16, in <module>
shlo = stablehlo.exported_program_to_stablehlo(exported_fn, options)
File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/stablehlo.py", line 541, in exported_program_to_stablehlo
bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/stablehlo.py", line 326, in _exported_program_to_stablehlo_bundle
stablehlo_content = xm.get_stablehlo_bytecode(res)
File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 947, in get_stablehlo_bytecode
return torch_xla._XLAC._get_stablehlo(
RuntimeError: Error while lowering: [] aten::where, xla_shape=f32[1,1,108,108]{3,2,1,0}, dynamic_dims: ()
Error: torch_xla/csrc/helpers.cpp:623 : Check failed: xla::ShapeUtil::SameElementType(shape1, shape2)
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::XlaHelpers::PromoteShapes(xla::XlaOp, xla::XlaOp)
torch_xla::Generic::Lower(torch_xla::LoweringContext*) const
torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
PyCFunction_Call
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCodeEx
PyEval_EvalCode
PyRun_SimpleFileExFlags
Py_RunMain
Py_BytesMain
__libc_start_main
*** End stack trace ***
f32[1,1,108,108]{3,2,1,0} and f64[1,1,108,108]{3,2,1,0}
Frames:
I0000 00:00:1707223684.467343 2471766 cpu_client.cc:373] TfrtCpuClient destroyed.
From the log above, the f32[1,1,108,108]{3,2,1,0} and f64[1,1,108,108]{3,2,1,0} operands of a where op have mismatched element types. I am confused about where the f64 dtype comes from, because I could not find it anywhere in the exported torch program.
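As a sanity check, here is a minimal sketch of how one could scan the exported program for float64 tensors, reusing exported_fn from the script above (node.meta["val"] is the FakeTensor that torch.export records for a node, and it may be missing for some nodes). It prints nothing for this model, so the f64 presumably appears only during the XLA lowering.
for node in exported_fn.graph.nodes:
    # node.meta["val"] holds the FakeTensor recorded by torch.export, if present
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor) and val.dtype == torch.float64:
        print(node.format_node())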
Expected behavior
The StableHLO text of the model's forward function is printed.
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU
- torch_xla version: 2.2.0
- torch version: 2.2.0
- transformers version: 4.37.2
Additional context
After some digging, I found a related issue and a PR that fixes it, but that PR was not merged.
Thanks for the additional context. Regarding the failure when dumping the model to StableHLO:
RuntimeError: Error while lowering: [] aten::where, xla_shape=f32[1,1,108,108]{3,2,1,0}, dynamic_dims: ()
Error: torch_xla/csrc/helpers.cpp:623 : Check failed: xla::ShapeUtil::SameElementType(shape1, shape2)
do you have more context on this issue, @qihqi, and is it OK to assign this ticket to you?
Any progress? @qihqi @ManfeiBai I also encountered the same issue.
Hi, sorry for the delay.
I just tried on HEAD and your example passes. My guess is that it was fixed by https://github.com/pytorch/xla/pull/6460
We were introducing f64 when scalars are used in binary ops (because a Python float is float64 on most platforms). The above PR fixes that.
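For illustration, a minimal sketch of the pattern (not the exact code path inside the OPT model): a Python float scalar combined with an f32 XLA tensor could previously be traced as an f64 constant, so aten::where ended up with one f32 and one f64 operand.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
mask = torch.zeros(1, 1, 108, 108, dtype=torch.bool, device=device)
scores = torch.zeros(1, 1, 108, 108, dtype=torch.float32, device=device)
# torch.finfo(...).min is a plain Python float (64-bit on most platforms); before the
# fix, the scalar branch of the where could carry f64 while the tensor branch stays f32.
masked = torch.where(mask, scores, torch.finfo(scores.dtype).min)
# Dump the lazy IR to inspect the element types of the where operands.
print(torch_xla._XLAC._get_xla_tensors_text([masked]))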
Thanks for the reply. So will this be released in torch-xla v2.2.1?
Same issue when running inference with the official Gemma PyTorch/XLA code on a TPU v4-8.
python run_xla.py --ckpt "/home/m0niusplus/gemma_pytorch/gemma-2b-pytorch/gemma-2b.ckpt"
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 261, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in _process_chunk
return [fn(*args) for args in chunk]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in <listcomp>
return [fn(*args) for args in chunk]
^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
replica_results = list(
^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
yield _result_or_cancel(fs.pop())
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
return fut.result(timeout)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
return fn()
^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 176, in __call__
self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 185, in generate
xm.mark_step()
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
torch_xla._XLAC._xla_step_marker(
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 267, in <module>
main(args)
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 239, in main
xmp.spawn(
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
replica_results = list(
^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
itertools.chain.from_iterable(
^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 620, in _chain_from_iterable_of_lists
for element in iterable:
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
yield _result_or_cancel(fs.pop())
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
return fut.result(timeout)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:
With PT_XLA_DEBUG=1
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: user mark_step
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 2f09bda1ca0ff6e0fbf3925676711b1
Compilation Analysis: Number of Graph Inputs: 3
Compilation Analysis: Number of Graph Outputs: 41
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py:891)
Compilation Analysis: generate (/home/m0niusplus/gemma_pytorch/run_xla.py:156)
Compilation Analysis: __call__ (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:176)
Compilation Analysis: _thread_fn (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:70)
Compilation Analysis: run (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:58)
Compilation Analysis: _worker (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:83)
Compilation Analysis: run (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:982)
Compilation Analysis: _bootstrap_inner (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:1045)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: user mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 2f09bda1ca0ff6e0fbf3925676711b1
Execution Analysis: Number of Graph Inputs: 3
Execution Analysis: Number of Graph Outputs: 41
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py:891)
Execution Analysis: generate (/home/m0niusplus/gemma_pytorch/run_xla.py:156)
Execution Analysis: __call__ (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:176)
Execution Analysis: _thread_fn (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:70)
Execution Analysis: run (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:58)
Execution Analysis: _worker (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:83)
Execution Analysis: run (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:982)
Execution Analysis: _bootstrap_inner (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:1045)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 261, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in _process_chunk
return [fn(*args) for args in chunk]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in <listcomp>
return [fn(*args) for args in chunk]
^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
replica_results = list(
^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
yield _result_or_cancel(fs.pop())
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
return fut.result(timeout)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
return fn()
^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 176, in __call__
self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 185, in generate
xm.mark_step()
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
torch_xla._XLAC._xla_step_marker(
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 268, in <module>
main(args)
File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 240, in main
xmp.spawn(
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
replica_results = list(
^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
itertools.chain.from_iterable(
^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 620, in _chain_from_iterable_of_lists
for element in iterable:
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
yield _result_or_cancel(fs.pop())
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
return fut.result(timeout)
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
OK, after hours of debugging I found the root cause: I was misled by the XLA docs into setting XLA_USE_BF16=1. After setting export XLA_USE_BF16=0, Gemma inference with XLA works fine.
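In case it helps others, roughly what I changed (a minimal sketch; setting the variable before torch_xla is imported is the safe option, since XLA_USE_BF16=1 maps float32 tensors to bfloat16 on the device and, as the error above shows, the aten::view_as_complex_copy lowering cannot build a complex type from bf16 components):
import os
# Clear the flag before torch_xla is imported; alternatively run
# `export XLA_USE_BF16=0` (or unset the variable) in the shell before launching run_xla.py.
os.environ["XLA_USE_BF16"] = "0"

import torch_xla.core.xla_model as xm  # import only after the env var is set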