Add Experimental limited sparse embedding bag
Users of torch_xla hit an error when using the `sparse=True` option with the `Embedding` or `EmbeddingBag` modules. The gradient for `weight` is created as a sparse tensor, but no dispatch is registered for the sparse-creation APIs with the XLA key, nor for the Sparse functionality key used in conjunction with the XLA backend key.
This is a workaround that can later be removed, ported to C++, or extended:
- `SparseCOOTensor`: a tensor subclass implementing the optimization semantics of the upstream sparse COO tensor. It is composable with the XLA device.
- Drop-in replacements for `F.embedding`, `F.embedding_bag`, `nn.Embedding`, and `nn.EmbeddingBag` which forward to a custom implementation of the backward and produce the above tensor subclass rather than a native torch sparse tensor.
The component `indices` and `values` tensors of the subclass can live on an XLA device without issue.
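The COO structure of the gradient can be illustrated without torch: the weight gradient of an embedding lookup is nonzero only at the rows that were actually looked up, so it is naturally stored as a list of row indices plus the corresponding gradient rows. A minimal sketch in plain Python, independent of torch_xla (the function name and shapes here are illustrative, not the PR's API):

```python
from collections import defaultdict

def embedding_backward_sparse(input_indices, grad_output):
    """Sketch of the sparse COO weight gradient of an embedding lookup.

    Only rows appearing in `input_indices` receive gradient, so the result
    is stored as (indices, values) rather than a dense weight-shaped matrix.
    Repeated lookups of the same row accumulate into one COO entry.
    """
    dim = len(grad_output[0])
    accum = defaultdict(lambda: [0.0] * dim)
    for row, grad_row in zip(input_indices, grad_output):
        acc = accum[row]
        for j, g in enumerate(grad_row):
            acc[j] += g
    indices = sorted(accum)
    values = [accum[i] for i in indices]
    return indices, values  # COO: nonzero row ids + their gradient rows

# Lookup of rows [1, 3, 1] with per-position upstream gradients:
idx, vals = embedding_backward_sparse(
    [1, 3, 1],
    [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]],
)
print(idx)   # [1, 3]
print(vals)  # [[3.0, 0.0], [0.0, 1.0]]
```

In the actual change, the `indices` and `values` components are real tensors held by the `SparseCOOTensor` subclass, which is what lets them sit on an XLA device even though no native sparse XLA kernels exist.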
fixes #8719
Hi @amjames, if this passes tests and is finished, feel free to publish as a PR and merge.
A note for reviewers: the failure in the xla_op1 shard appears to be unrelated; the tests for the new feature are in xla_op3.
I think it would be nice to have tests for each operation implemented. What do you think?