
Lowering aten::where encounters mismatched types and crashes


🐛 Bug

Hi Pytorch/XLA developers,

I am trying to dump a Hugging Face OPT model to StableHLO with torch_xla. The model is exported by torch.export successfully, but the conversion crashes when lowering the exported program to StableHLO on the PyTorch/XLA side. It looks like a bug in PyTorch/XLA's lowering of the aten::where op.

Any help would be appreciated.

To Reproduce

Steps to reproduce the behavior:

  1. install the Hugging Face transformers lib
pip install transformers==4.37.2
  2. run the following script
from transformers import OPTForCausalLM

import torch
from torch_xla import stablehlo


model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b")
model.eval()
input = torch.randint(0, 100, size=(1, 108))
exported_fn = torch.export.export(model, args=(input,))
# works fine until here
# print(exported_fn)

options = stablehlo.StableHLOExportOptions()
# options.override_tracing_arguments = (m_args,)
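# crashes here: RuntimeError while lowering aten::where (see the log below)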
shlo = stablehlo.exported_program_to_stablehlo(exported_fn, options)
method = shlo._name_to_stablehlo["forward"]
ir_str = method.text
print(ir_str)
  3. see the output
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707223682.454894 2471766 cpu_client.cc:370] TfrtCpuClient created.
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    shlo = stablehlo.exported_program_to_stablehlo(exported_fn, options)
  File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/stablehlo.py", line 541, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
  File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/stablehlo.py", line 326, in _exported_program_to_stablehlo_bundle
    stablehlo_content = xm.get_stablehlo_bytecode(res)
  File "/home/dev/miniconda3/envs/spu/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 947, in get_stablehlo_bytecode
    return torch_xla._XLAC._get_stablehlo(
RuntimeError: Error while lowering: [] aten::where, xla_shape=f32[1,1,108,108]{3,2,1,0}, dynamic_dims: ()
Error: torch_xla/csrc/helpers.cpp:623 : Check failed: xla::ShapeUtil::SameElementType(shape1, shape2) 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::XlaHelpers::PromoteShapes(xla::XlaOp, xla::XlaOp)
	
	torch_xla::Generic::Lower(torch_xla::LoweringContext*) const
	torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
	torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
	torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
	torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
	torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
	
	
	
	PyCFunction_Call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCodeEx
	PyEval_EvalCode
	
	
	
	PyRun_SimpleFileExFlags
	Py_RunMain
	Py_BytesMain
	
	__libc_start_main
	
*** End stack trace ***
f32[1,1,108,108]{3,2,1,0} and f64[1,1,108,108]{3,2,1,0}
Frames:

I0000 00:00:1707223684.467343 2471766 cpu_client.cc:373] TfrtCpuClient destroyed.

From the log above, the f32[1,1,108,108]{3,2,1,0} and f64[1,1,108,108]{3,2,1,0} operands of a where op have mismatched element types. I am confused about where the f64 dtype comes from, because I could not find it in the exported torch program.

Expected behavior

The StableHLO text of this model's forward function is printed.

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU
  • torch_xla version: 2.2.0
  • torch version: 2.2.0
  • transformers version: 4.37.2

Additional context

After some digging, I found a related issue and a PR that fixes it, but that PR was not merged.

tpppppub avatar Feb 06 '24 13:02 tpppppub

Thanks for the additional context. Given this failure when dumping the model to StableHLO:

RuntimeError: Error while lowering: [] aten::where, xla_shape=f32[1,1,108,108]{3,2,1,0}, dynamic_dims: ()
Error: torch_xla/csrc/helpers.cpp:623 : Check failed: xla::ShapeUtil::SameElementType(shape1, shape2) 

do you have more context on this issue, @qihqi? Is it ok to assign this ticket to you?

ManfeiBai avatar Feb 09 '24 23:02 ManfeiBai

Any progress? @qihqi @ManfeiBai I also encountered the same issue.

llCurious avatar Feb 20 '24 08:02 llCurious

Hi, sorry for the delay.

Just tried on HEAD and your example passes. My guess is that it was fixed by https://github.com/pytorch/xla/pull/6460

We introduce f64 when scalars are used in binary ops (because a Python float is float64 on most platforms). The above PR fixes that.
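For illustration, here is a minimal sketch of that promotion (the mask shape and the torch.finfo scalar are assumptions chosen to mimic the attention-mask pattern in the log, not the exact OPT graph); on an affected version such as torch_xla 2.2.0, forcing the lazy graph to lower reproduces the SameElementType failure:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# float32 scores and a boolean mask, shaped like the f32[1,1,108,108] operand in the log
mask = torch.zeros(1, 1, 108, 108, dtype=torch.bool, device=device)
scores = torch.zeros(1, 1, 108, 108, dtype=torch.float32, device=device)

# torch.finfo(torch.float32).min is a Python float, i.e. a float64 value; on affected
# versions the scalar branch can be lowered as an f64 constant, so the
# SameElementType(f32, f64) check fails while lowering aten::where.
out = torch.where(mask, scores, torch.finfo(torch.float32).min)

xm.mark_step()  # force lowering of the lazy graph; this is where the error surfaces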

qihqi avatar Feb 22 '24 00:02 qihqi

Thanks for the reply. So will it be released in torch-xla v2.2.1?

tpppppub avatar Feb 22 '24 02:02 tpppppub

Same issue here, when running inference with the official Gemma PyTorch/XLA code on a TPU v4-8.

python run_xla.py  --ckpt "/home/m0niusplus/gemma_pytorch/gemma-2b-pytorch/gemma-2b.ckpt"
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 261, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in _process_chunk
    return [fn(*args) for args in chunk]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in <listcomp>
    return [fn(*args) for args in chunk]
            ^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
    replica_results = list(
                      ^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
    return fn()
           ^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 176, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 185, in generate
    xm.mark_step()
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:

"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 267, in <module>
    main(args)
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 239, in main
    xmp.spawn(
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
    replica_results = list(
                      ^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
    itertools.chain.from_iterable(
                                 ^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 620, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:

With PT_XLA_DEBUG=1

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   user mark_step
Compilation Analysis: Graph Info:
Compilation Analysis:   Graph Hash: 2f09bda1ca0ff6e0fbf3925676711b1
Compilation Analysis:   Number of Graph Inputs: 3
Compilation Analysis:   Number of Graph Outputs: 41
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis:   mark_step (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py:891)
Compilation Analysis:   generate (/home/m0niusplus/gemma_pytorch/run_xla.py:156)
Compilation Analysis:   __call__ (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:176)
Compilation Analysis:   _thread_fn (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:70)
Compilation Analysis:   run (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:58)
Compilation Analysis:   _worker (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:83)
Compilation Analysis:   run (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:982)
Compilation Analysis:   _bootstrap_inner (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:1045)
Compilation Analysis:   ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   user mark_step
Execution Analysis: Graph Info:
Execution Analysis:   Graph Hash: 2f09bda1ca0ff6e0fbf3925676711b1
Execution Analysis:   Number of Graph Inputs: 3
Execution Analysis:   Number of Graph Outputs: 41
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis:   mark_step (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py:891)
Execution Analysis:   generate (/home/m0niusplus/gemma_pytorch/run_xla.py:156)
Execution Analysis:   __call__ (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:176)
Execution Analysis:   _thread_fn (/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py:70)
Execution Analysis:   run (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:58)
Execution Analysis:   _worker (/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py:83)
Execution Analysis:   run (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:982)
Execution Analysis:   _bootstrap_inner (/opt/dev/conda/envs/xla/lib/python3.11/threading.py:1045)
Execution Analysis:   ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 261, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in _process_chunk
    return [fn(*args) for args in chunk]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 210, in <listcomp>
    return [fn(*args) for args in chunk]
            ^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
    replica_results = list(
                      ^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
    return fn()
           ^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 176, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 185, in generate
    xm.mark_step()
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:

"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 268, in <module>
    main(args)
  File "/home/m0niusplus/gemma_pytorch/run_xla.py", line 240, in main
    xmp.spawn(
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
    replica_results = list(
                      ^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
    itertools.chain.from_iterable(
                                 ^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/process.py", line 620, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/dev/conda/envs/xla/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
RuntimeError: Error while lowering: [] aten::view_as_complex_copy, xla_shape=bf16[1,1,6,128]{3,2,1,0}, dynamic_dims: ()
XLA builder error: UNIMPLEMENTED: Complex component type is not implemented.:
Frames:

WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================

Mon-ius avatar Mar 18 '24 09:03 Mon-ius

Okay, after hours of debugging I found the root cause: I was misled by the XLA docs into setting XLA_USE_BF16=1. After switching to export XLA_USE_BF16=0, Gemma inference with XLA works fine.
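For reference, a minimal sketch of the same workaround from Python instead of the shell (assumption on my part: the variable needs to be set or cleared before torch_xla is first imported for it to take effect):

import os

# Equivalent to `export XLA_USE_BF16=0` in the shell (or simply leaving the variable unset).
# Assumption: torch_xla reads XLA_USE_BF16 at initialization time, so set it before importing.
os.environ["XLA_USE_BF16"] = "0"

import torch_xla.core.xla_model as xm  # import torch_xla only after the env var is set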

Mon-ius avatar Mar 18 '24 11:03 Mon-ius