Make einsum a leaf
#1225
XLA has an optimized einsum implementation that we can use. This requires an upstream change in PyTorch.
@bdhirsh Do you think this is possible?
@JackCaoG This is the einsum op issue we just mentioned. From our earlier profiling, lowering it could potentially bring a 5%+ speedup to several models. However, as mentioned in this issue and #2385, it requires an upstream change in PyTorch to dispatch it.
(Oh, it seems that we raced on commenting and you're already on this thread)
@ezyang in case you have some insight 😄
We need to write a backward kernel for einsum. Do you have an einsum_backward op? We could stub one in and just not have an implementation in PT proper.
I didn't find anything from a quick search; I will check with the XLA team regarding einsum backward.
So there isn't an einsum_backward in XLA natively, but after talking with Blake I think we can implement it using einsum itself. In Blake's words:
it is easy enough to do in pytorch xla or in xla client
you just need to parse the comma string
and swap the operand string and output string
for the operand you want to take the derivative with respect to
so like "...a,...ab->...b" would go to "...a,...b->...ab"
and "...b,...ab->...a"
to get operand 1 and operand 0 gradients respectively
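A minimal sketch of that swapping rule using plain torch.einsum (the shapes and variable names here are illustrative, not from this thread):

```python
import torch

a = torch.randn(3, 4, requires_grad=True)     # "...a":  batch 3, a=4
b = torch.randn(3, 4, 5, requires_grad=True)  # "...ab": batch 3, a=4, b=5

out = torch.einsum("...a,...ab->...b", a, b)  # shape (3, 5)
grad_out = torch.randn_like(out)              # incoming gradient, term "...b"

# Swap the output term with the term of the operand we differentiate w.r.t.
grad_b = torch.einsum("...a,...b->...ab", a, grad_out)  # operand 1 gradient
grad_a = torch.einsum("...b,...ab->...a", grad_out, b)  # operand 0 gradient

# Compare against autograd.
out.backward(grad_out)
print(torch.allclose(grad_a, a.grad), torch.allclose(grad_b, b.grad))  # True True
```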
@ezyang If you could make einsum and einsum_backward leaf nodes, I will try to lower them using xla::Einsum and test them in pytorch/xla.
I don't think we need to do anything in PyTorch; just add einsum to the autograd list in xla_native_functions.yaml and then implement the custom autograd function the same way as the other ops. We could upstream them later, but this is probably the easiest path.
Oh OK. It seems like einsum is already something we can lower https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L1892. I will try that. Thanks Ed!
@steventk-g You can use https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.h#L10 as an example of how to write both the forward and backward parts for einsum. You need to put einsum under here
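For reference, a hypothetical Python sketch of such a custom autograd function (the real one would live in C++ alongside the other ops in aten_autograd_ops; the class name, the two-operand restriction, and the string handling here are all assumptions for illustration):

```python
import torch

class EinsumFn(torch.autograd.Function):
    """Hypothetical einsum whose backward is built from einsum itself.

    Only handles two operands, with no repeated indices within a term and
    every index appearing in at least two of the three terms.
    """

    @staticmethod
    def forward(ctx, equation, x, y):
        ctx.equation = equation
        ctx.save_for_backward(x, y)
        return torch.einsum(equation, x, y)

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        inputs, out = ctx.equation.split("->")
        lhs, rhs = inputs.split(",")
        # Swap the output term with the term of the operand being
        # differentiated, per Blake's recipe above.
        grad_x = torch.einsum(f"{out},{rhs}->{lhs}", grad_out, y)
        grad_y = torch.einsum(f"{lhs},{out}->{rhs}", x, grad_out)
        return None, grad_x, grad_y  # no gradient for the equation string

# Usage: EinsumFn.apply("...a,...ab->...b", a, b)
```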
After https://github.com/pytorch/xla/pull/3843, we will need changes to support (1) einsum on more than 2 inputs and (2) einsum on equations like ijj,k->ik, where one input or output has an index that none of the other inputs or outputs have. For now, we fall back to the at::native implementation in these cases.
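A hypothetical illustration of why the plain swap recipe breaks for the second case: swapping terms in an equation like ijj,k->ik produces an output term with a repeated index, which einsum rejects.

```python
import torch

x = torch.randn(2, 3, 3)
k = torch.randn(4)

# Forward works: take the diagonal over j, then an outer product with k.
out = torch.einsum("ijj,k->ik", x, k)

# The swapped equation for operand 0 would be "ik,k->ijj", but an index
# may not repeat in the output term, so this raises an error.
try:
    torch.einsum("ik,k->ijj", out, k)
except RuntimeError as e:
    print("swap recipe fails:", e)
```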