
Add Experimental limited sparse embedding bag

Open · amjames opened this issue 8 months ago · 3 comments

Users of torch_xla encounter an issue when using the sparse=True option with the Embedding or EmbeddingBag modules.

The gradient for the weight is created as a sparse tensor, and there is no dispatch registered for the sparse tensor creation APIs with the XLA key, nor for the Sparse functionality key used in conjunction with the XLA backend key.
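
For context, a minimal sketch of the failure mode might look like the following (this repro is illustrative, assuming torch_xla is installed and an XLA device is available; the exact error text may vary between versions):

```python
# Minimal repro sketch (hypothetical; assumes torch_xla is installed and an
# XLA device is available).
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, sparse=True).to(device)
idx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device=device)
offsets = torch.tensor([0, 4], device=device)

out = emb(idx, offsets)
# The backward pass tries to build a sparse COO gradient for emb.weight.
# No sparse dispatch is registered for the XLA backend, so this raises.
out.sum().backward()
```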

This is a workaround that can be removed, ported to C++, or extended later:

  • SparseCOOTensor: a tensor subclass implementing the optimizations and semantics of the upstream SparseTensor. It is composable with the XLA device.
  • Drop-in replacements for F.embedding, F.embedding_bag, nn.Embedding, and nn.EmbeddingBag, which forward to a custom implementation of the backward pass and produce the above tensor subclass rather than a native torch sparse tensor.

The tensor subclass's component tensors, indices and values, can live on the XLA device without issue.
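
As a rough illustration of that idea (this sketch is not the SparseCOOTensor added by this PR; the class and its methods are hypothetical names), a single-sparse-dim COO wrapper built only from dense ops could look like:

```python
# Hypothetical sketch, not the implementation in this PR: a COO-style wrapper
# whose components are plain dense tensors, so both can live on an XLA device.
import torch

class SparseCOOTensorSketch:
    """Illustrative stand-in; assumes a single sparse dim, as with embedding
    weight gradients of shape (num_embeddings, embedding_dim)."""

    def __init__(self, indices: torch.Tensor, values: torch.Tensor, size):
        self.indices = indices  # shape (1, nnz), dense, may be an XLA tensor
        self.values = values    # shape (nnz, embedding_dim), dense, may be XLA
        self.size = tuple(size)

    def to_dense(self) -> torch.Tensor:
        # Accumulate duplicate indices with index_add_, which uses only dense
        # ops and therefore lowers on the XLA backend without sparse kernels.
        out = torch.zeros(self.size, dtype=self.values.dtype,
                          device=self.values.device)
        out.index_add_(0, self.indices[0], self.values)
        return out
```

A real tensor subclass would additionally integrate with torch.Tensor dispatch so optimizers can consume the sparse gradient; the point here is only that indices and values remain ordinary dense XLA tensors.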

fixes #8719

amjames avatar Mar 28 '25 22:03 amjames

Hi @amjames, if this passes the tests and is finished, feel free to publish it as a PR and merge.

qihqi avatar Apr 11 '25 20:04 qihqi

A note for reviewers: the failure in the xla_op1 shard appears to be unrelated; the tests for the new feature are in xla_op3.

amjames avatar May 30 '25 20:05 amjames

I think it would be nice to have tests for each operation implemented. What do you think?

ysiraichi avatar Jun 05 '25 14:06 ysiraichi