Support non-traceable Custom Ops with opaque arguments
🚀 Feature
torch_xla.stablehlo supports exporting custom ops to StableHLO custom calls for tensor arguments. We would like to be able to export custom ops that take arbitrary opaque strings as arguments to StableHLO.
Motivation
Some custom operations come from external C sources and are used through pybindings during inference.
These operations sometimes take POD structures that are not necessarily tensors as arguments, somewhat like the opaque Descriptor example in the JAX custom-op tutorial.
Such operations can appear at any point in the model; they are usually `([opaque structs]) -> tensors` or `(tensors) -> [opaque structs]`, but we could also imagine an op in the middle of a model having side effects on an opaque external structure.
Pitch
Here is example PyTorch code and what the HLO could potentially look like.
The idea is to be able to declare some arguments as "external" so that the export keeps them in the outer function and annotates them with attributes, which would be used downstream to lower them to opaque pointers and sizes.
My example is based on https://github.com/pytorch/xla/pull/7017
```python
@impl(m, "custom_op_external", "XLA")
def custom_op_external_xla(external_input, x):
    res = stablehlo_custom_call(
        (external_input, x), "custom_op_external",
        [(external_input.shape[0],), x.shape[1:]],
        [torch.int8, torch.int8], True, "backend_config", 1)
    return res
```
```python
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.external = torch.empty(32)

    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op_external(self.external, x)
        x = x + 1
        return x

ep = torch.export.export(M(), (torch.randn(3, 3),))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
```
```mlir
module @IrToHlo.10 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<32xi8> {external: true}) -> tensor<3xi8> {
    %c = stablehlo.constant dense<1> : tensor<3xi8>
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @custom_op_external(%arg1, %0) {backend_config = "backend_config", has_side_effect = true} : (tensor<32xi8>, tensor<3x3xf32>) -> tensor<3xi8>
    %2 = stablehlo.add %1, %c : tensor<3xi8>
    return %2 : tensor<3xi8>
  }
}
```
@qihqi @lsy323
Currently I don't think you can register a custom op with torch using types that are not defined in native_functions.yaml. That said, they do have str as a dtype.
Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.
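To make the "int tensors holding bytes" suggestion concrete, here is a minimal sketch (not part of any torch_xla API; the field layout is purely illustrative) of packing an opaque POD payload into a `uint8` tensor and recovering it on the other side:

```python
import struct

import torch

# Hypothetical POD payload: two int32 fields and a float32, packed
# little-endian the way a C struct might lay them out.
payload = struct.pack("<iif", 32, 3, 0.5)

# Wrap the raw bytes in a uint8 tensor so they can flow through the
# graph like any other tensor argument.
opaque = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
print(opaque.shape)  # torch.Size([12])

# The consumer recovers the bytes and reinterprets the fields.
fields = struct.unpack("<iif", opaque.numpy().tobytes())
print(fields)  # (32, 3, 0.5)
```

The tensor itself carries no hint that it is opaque rather than numeric data, which is exactly the gap the proposed `external` annotation would fill.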
Hi @qihqi, is it OK to assign this ticket to you?
> Also, I am curious why not just use int tensors to hold the bytes, as you have shown in the example above. That should already work.
We would like to expose already tested and optimized implementations in C, which do not necessarily take tensors.
We could definitely use a Tensor object to hold opaque/random bytes and reinterpret-cast it in the implementation (and that is part of the idea), but to know whether such a tensor holds an actual tensor or an opaque string, it needs to be annotated somehow when we export from torch to HLO. The annotations can then be used to lower the custom call arguments to opaque pointers accordingly.
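The reinterpret-cast half of that idea can be sketched in Python with ctypes; the `Descriptor` struct and its fields below are hypothetical, standing in for whatever POD the C implementation actually expects:

```python
import ctypes

import torch

# Hypothetical mirror of a C-side POD; field names are illustrative only.
class Descriptor(ctypes.Structure):
    _fields_ = [("rows", ctypes.c_int32),
                ("cols", ctypes.c_int32),
                ("scale", ctypes.c_float)]

desc = Descriptor(3, 3, 0.5)

# The struct's raw bytes travel through the graph as a plain uint8 tensor...
opaque = torch.frombuffer(bytearray(bytes(desc)), dtype=torch.uint8)

# ...and the custom-op implementation reinterprets them back into the struct.
recovered = Descriptor.from_buffer_copy(opaque.numpy().tobytes())
print(recovered.rows, recovered.cols, recovered.scale)  # 3 3 0.5
```

Nothing in this round trip distinguishes the tensor from real data, so the export-time annotation is still needed to tell the lowering which arguments are opaque.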
We could maybe use https://github.com/pytorch/xla/pull/7046 to introduce the annotation, but I was thinking we could perhaps find a more generic solution.
Hi @lsy323, is it OK to assign this ticket to you too, since @qihqi is OOO now?