
Make einsum a leaf

Open ailzhang opened this issue 6 years ago • 11 comments

#1225 XLA has an optimized einsum implementation that we can use. Requires an upstream change.

ailzhang avatar Nov 14 '19 19:11 ailzhang

@bdhirsh Do you think this is possible?

JackCaoG avatar Jul 08 '22 20:07 JackCaoG

@JackCaoG This is the einsum op issue we just mentioned. From our earlier profiling, lowering it could potentially bring a 5%+ speedup to several models. However, as mentioned in this issue and #2385, it requires an upstream change in PyTorch to dispatch it.

ronghanghu avatar Jul 08 '22 21:07 ronghanghu

(Oh, it seems that we raced on commenting and you're already on this thread)

ronghanghu avatar Jul 08 '22 21:07 ronghanghu

@ezyang in case you have some insight 😄

JackCaoG avatar Jul 13 '22 22:07 JackCaoG

We need to write a backward kernel for einsum. Do you have an einsum_backward op? We could stub one in and just not have an implementation in PT proper.

ezyang avatar Jul 14 '22 03:07 ezyang

I didn't find anything in a quick search; I will check with the XLA team regarding einsum backward.

JackCaoG avatar Jul 14 '22 03:07 JackCaoG

So there isn't an einsum_backward in XLA natively, but after talking with Blake I think we can implement it using einsum itself. In Blake's words:

it is easy enough to do in pytorch xla or in xla client
you just need to parse the comma string
and swap the operand string and output string
for the operand you want to take the derivative with respect to
so like "...a,...ab->...b" would go to "...a,...b->...ab"
and "...b,...ab->...a"
to get operand 1 and operand 0 gradients respectively
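In other words, each gradient is itself an einsum call with the equation strings swapped. A minimal sketch of that recipe for the two-operand case, using torch.einsum purely for illustration (the actual lowering would go through xla::Einsum, and the shapes below are made up):

```python
import torch

# Sketch of the spec-swapping recipe above for "...a,...ab->...b".
op0 = torch.randn(2, 3, requires_grad=True)      # spec "...a"
op1 = torch.randn(2, 3, 4, requires_grad=True)   # spec "...ab"
out = torch.einsum("...a,...ab->...b", op0, op1)
grad_out = torch.randn_like(out)

# Swap the spec of the operand being differentiated with the output spec:
grad_op1 = torch.einsum("...a,...b->...ab", op0, grad_out)   # d out / d op1
grad_op0 = torch.einsum("...b,...ab->...a", grad_out, op1)   # d out / d op0

# Cross-check against autograd
out.backward(grad_out)
assert torch.allclose(grad_op0, op0.grad) and torch.allclose(grad_op1, op1.grad)
```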

@ezyang If you could make einsum and einsum_backward leaf nodes, I will try to lower them using xla::Einsum and test them in pytorch/xla.

JackCaoG avatar Jul 14 '22 22:07 JackCaoG

I don't think we need to do anything in PyTorch; just add einsum to the autograd list in xla_native_functions.yaml and then implement the custom autograd function the same as the other ops. We could upstream them but this is probably easiest.

ezyang avatar Jul 21 '22 04:07 ezyang

Oh OK. It seems like einsum is already something we can lower https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L1892. I will try that. Thanks Ed!

JackCaoG avatar Jul 21 '22 17:07 JackCaoG

@steventk-g You can use https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.h#L10 as an example to write both the forward and backward parts for einsum. You need to put einsum under here
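For reference, a rough Python analogue of that forward/backward pattern; the real implementation is a C++ autograd function in aten_autograd_ops.h/.cpp, so the class name and the two-operand restriction below are only illustrative:

```python
import torch

# Python sketch of the forward/backward pair; pytorch/xla implements this as
# a C++ torch::autograd::Function and lowers the forward through xla::Einsum.
class EinsumAutogradSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, equation, *operands):
        ctx.equation = equation
        ctx.save_for_backward(*operands)
        return torch.einsum(equation, *operands)

    @staticmethod
    def backward(ctx, grad_output):
        # Two-operand case only, using the spec-swapping recipe from above.
        op0, op1 = ctx.saved_tensors
        in_specs, out_spec = ctx.equation.split("->")
        spec0, spec1 = in_specs.split(",")
        grad_op0 = torch.einsum(f"{out_spec},{spec1}->{spec0}", grad_output, op1)
        grad_op1 = torch.einsum(f"{spec0},{out_spec}->{spec1}", op0, grad_output)
        return None, grad_op0, grad_op1  # None for the non-tensor equation arg

# Usage:
#   out = EinsumAutogradSketch.apply("...a,...ab->...b", x, y)
```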

JackCaoG avatar Aug 11 '22 23:08 JackCaoG

After https://github.com/pytorch/xla/pull/3843, we will need changes to support (1) einsum on more than 2 inputs and (2) einsum on equations like ijj,k->ik, where one input or output has an index that none of the other inputs or outputs have. For now, we fall back to the at::native implementation in these cases.
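For illustration, these are the kinds of calls that take the fallback path (shapes are made up):

```python
import torch

x = torch.randn(2, 3)
y = torch.randn(3, 4)
z = torch.randn(4, 5)

# (1) more than two inputs
out1 = torch.einsum("ij,jk,kl->il", x, y, z)

# (2) an index that appears in only one operand, e.g. the repeated
#     (diagonal) index j in "ijj,k->ik"
a = torch.randn(2, 3, 3)
b = torch.randn(4)
out2 = torch.einsum("ijj,k->ik", a, b)
```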

steventk-g avatar Sep 14 '22 23:09 steventk-g