
Failing `torch_xla._XLAC._xla_custom_call()` with `RuntimeError: Bad StatusOr access: UNIMPLEMENTED: No registered implementation for custom call to my_lib.my_op.default for platform CUDA`

Open hsjts0u opened this issue 6 months ago • 4 comments

❓ Questions and Help

Execution of torch_xla.stablehlo.exported_program_to_stablehlo() fails with RuntimeError: Bad StatusOr access: UNIMPLEMENTED: No registered implementation for custom call to my_lib.my_op.default for platform CUDA. For more context, my_op is registered under a custom library as follows:

import torch
from torch.library import Library, impl, impl_abstract

MY_LIB = Library("my_lib", "DEF")

MY_LIB.define("my_op(Tensor t) -> Tensor")


# Default (composite) implementation: identity for now.
@impl(f"{MY_LIB.ns}::my_op", "default")
def my_op(t):
    return t


# Abstract/meta implementation used for tracing and shape propagation.
@impl_abstract(f"{MY_LIB.ns}::my_op")
def my_op_meta(t):
    return torch.empty_like(t)

I am able to get the torch ExportedProgram, and the MY_LIB namespace is allowed in the StableHLO graph as a custom op by specifying:

StableHLOExportOptions(
    custom_ops_allowed_in_graph={MY_LIB.ns}
)
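
For completeness, here is roughly how I drive the export; the wrapper module and input shape below are placeholders for my actual model:

import torch
from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo

class Wrapper(torch.nn.Module):
    def forward(self, t):
        # Calls the custom op registered above.
        return torch.ops.my_lib.my_op(t)

options = StableHLOExportOptions(custom_ops_allowed_in_graph={MY_LIB.ns})
exported = torch.export.export(Wrapper(), (torch.randn(4, 4),))
stablehlo_program = exported_program_to_stablehlo(exported, options)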

It seems to me that if XLA does not attempt to execute the graph, the error is not raised. I have a few questions here:

  1. How can I get around this RuntimeError?
  2. Does registering a custom op under torch.library (the way I did in the first code snippet) not expose the implementation to XLA? (A sketch of what I mean is below.)
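
To illustrate question 2, this is the kind of registration I am wondering about; I have not confirmed that it helps, so treat it as a guess rather than a known fix:

# Guess: also register the op for the XLA dispatch key, so the XLA
# backend has a kernel to dispatch to instead of an unresolved custom call.
def my_op_xla(t):
    return t

MY_LIB.impl("my_op", my_op_xla, "XLA")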

hsjts0u avatar Jun 16 '25 21:06 hsjts0u

Could you post the raised error together with the backtrace, please? @bhavya01 @tengyifei Any thoughts?

ysiraichi avatar Jun 17 '25 11:06 ysiraichi

...

  File "/home/htsou/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 600, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 353, in _exported_program_to_stablehlo_bundle
    res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 296, in run_node
    return super().run_node(n)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 276, in call_function
    return super().call_function(target, args, new_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/htsou/.local/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Bad StatusOr access: UNIMPLEMENTED: No registered implementation for custom call to my_lib.my_op.default for platform CUDA

While executing %mul_6826 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%my_op_75926, 1.0), kwargs = {})
Original traceback:
  File "<eval_with_key>.31122", line 685, in forward
    inner__tensor_constant_policy_0_rearrange_264_folded = self.inner._tensor_constant_policy_0_rearrange_264_folded

hsjts0u avatar Jun 18 '25 04:06 hsjts0u

@lsy323 Do you know if StableHLO export works for custom calls?

bhavya01 avatar Jun 24 '25 17:06 bhavya01

It seems to me that XLA should only ever be lazy during the call to _exported_program_to_stablehlo. However, when the number of nodes in the graph exceeds a certain threshold, XLA wants to materialize the tensors and hence execute the graph. For whatever reason, custom calls registered through torch.library are not accessible to XLA; I don't know whether they were ever supposed to be in the first place. After reducing the number of nodes, _exported_program_to_stablehlo runs just fine. Can someone confirm that XLA materializes the graph when the number of operations/nodes exceeds a certain threshold, and if so, whether there is a knob to tune that threshold? (My guess at such a knob is sketched below.)
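
From skimming the graph executor source, these environment variables look like they control when a pending graph gets flushed; this is an assumption on my part, not something documented as the fix:

import os

# Assumption: torch_xla appears to read these to decide when a pending graph
# is "too big" and must be executed. Raising them before importing torch_xla
# might delay the materialization that triggers the custom-call error.
os.environ["XLA_TRIM_GRAPH_SIZE"] = "1000000"           # node-count threshold
os.environ["XLA_TRIM_GRAPH_CHECK_FREQUENCY"] = "10000"  # how often to check it

import torch_xla  # must be imported after the env vars are set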

hsjts0u avatar Jun 24 '25 18:06 hsjts0u