
Support non-traceable Custom Ops with opaque arguments

Open tle-huu opened this issue 1 year ago • 5 comments

🚀 Feature

torch_xla.stablehlo supports exporting custom ops to stablehlo custom calls for tensor arguments. We would like to be able to export custom ops taking an arbitrary opaque string as an argument to StableHLO.

Motivation

Some custom operations come from external C sources and are used through pybindings during inference.

Those operations sometimes take POD structures that are not necessarily tensors as arguments, somewhat like the opaque Descriptor example in the JAX custom op tutorial.

Such operations can be used at any point in the model; they are usually ([opaque structs]) -> tensors or (tensors) -> [opaque struct], but we could imagine an op in the middle of a model having side effects on an opaque external structure.

Pitch

Here is example PyTorch code and what the HLO could potentially look like.

The idea is to be able to declare some arguments as "external" so that the export lifts them into the top-level function and annotates them with attributes, which would be used downstream to lower them to opaque pointers and sizes.

My example is based on https://github.com/pytorch/xla/pull/7017

@impl(m, "custom_op_external", "XLA")
def custom_op_external_xla(external_input, x):
  res = stablehlo_custom_call(
      (external_input, x), "custom_op_external",
      [(external_input.shape[0],), x.shape[1:]],
      [torch.int8, torch.int8], True, "backend_config", 1)
  return res

class M(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.external = torch.empty(32, dtype=torch.int8)

  def forward(self, x):
    x = torch.sin(x)
    x = torch.ops.my_custom_library.custom_op_external(self.external, x)
    x = x + 1
    return x

ep = torch.export.export(M(), (torch.randn(3, 3), ))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
module @IrToHlo.10 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<32xi8> {external: true}) -> tensor<3xi8> {
    %c = stablehlo.constant dense<1> : tensor<3xi8>
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @custom_op_external(%arg1, %0) {backend_config = "backend_config", has_side_effect = true} : (tensor<32xi8>, tensor<3x3xf32>) -> tensor<3xi8>
    %2 = stablehlo.add %1, %c : tensor<3xi8>
    return %2 : tensor<3xi8>
  }
}

tle-huu avatar Jun 22 '24 01:06 tle-huu

@qihqi @lsy323

JackCaoG avatar Jun 24 '24 17:06 JackCaoG

Currently I don't think you can register a custom op with torch using types that are not defined in native_functions.yaml. That said, they do have str as a dtype.

Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.
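To illustrate the suggestion, here is a minimal, hedged sketch of packing a POD argument into raw bytes so it could travel as an int8 tensor through `torch.export` (e.g. via `torch.frombuffer(bytearray(raw), dtype=torch.int8)`). The `"<iif"` descriptor layout (two int32 fields and a float32) is purely hypothetical:

```python
import struct

# Hypothetical descriptor: rows (int32), cols (int32), scale (float32),
# little-endian. In practice the raw bytes could be wrapped as
# torch.frombuffer(bytearray(raw), dtype=torch.int8) to obtain a tensor.
def pack_descriptor(rows: int, cols: int, scale: float) -> bytes:
    return struct.pack("<iif", rows, cols, scale)

def unpack_descriptor(raw: bytes):
    # The C implementation would do the equivalent reinterpretation
    # of the tensor's byte buffer.
    return struct.unpack("<iif", raw)

raw = pack_descriptor(3, 3, 0.5)
print(len(raw))               # 12 -> an int8 tensor of shape (12,)
print(unpack_descriptor(raw)) # (3, 3, 0.5)
```

The round trip is lossless as long as the struct layout is fixed on both the Python and C sides.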

qihqi avatar Jun 25 '24 00:06 qihqi

Hi, @qihqi, is that ok to assign this ticket to you?

ManfeiBai avatar Jun 25 '24 06:06 ManfeiBai

Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.

We would like to expose already tested and optimized implementations in C, which do not necessarily take tensors.

We could definitely use a "Tensor" object to hold opaque/random bytes and reinterpret-cast it in the implementation (and that is part of the idea), but to know whether such a tensor holds an actual tensor or an opaque string, it needs to be annotated somehow when we export from torch to HLO. The annotations can then be used to lower the custom call arguments accordingly to opaque pointers.
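To make the reinterpret-cast idea concrete, here is a hedged Python sketch (using ctypes to stand in for the C side) of what the custom call implementation would do with the opaque byte buffer carried by an annotated int8 tensor. The `Descriptor` layout and `custom_call_impl` name are hypothetical:

```python
import ctypes

# Hypothetical POD layout matching a struct on the C side.
class Descriptor(ctypes.Structure):
    _fields_ = [("rows", ctypes.c_int32), ("cols", ctypes.c_int32)]

def custom_call_impl(opaque: bytes):
    # In C this would be: const Descriptor* d = (const Descriptor*)opaque;
    d = Descriptor.from_buffer_copy(opaque)
    return d.rows, d.cols

# The bytes the annotated int8 tensor would carry at runtime.
raw = bytes(Descriptor(3, 4))
print(custom_call_impl(raw))  # (3, 4)
```

The annotation is what tells the lowering that this argument's buffer should be passed through untouched as a pointer/size pair rather than treated as tensor data.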

We could maybe use https://github.com/pytorch/xla/pull/7046 to introduce the annotation, but I was thinking maybe we could find a more generic solution.

tle-huu avatar Jun 25 '24 16:06 tle-huu

Hi, @lsy323, is that ok to assign this ticket to you too? since @qihqi is OOO now

ManfeiBai avatar Jul 01 '24 21:07 ManfeiBai