
[DETR support] Lower aten::_cdist_forward

Status: Open · ymwangg opened this issue 2 years ago · 2 comments

This op is needed to support the DETR model. We'll use this thread to track progress.

  • Owner: @codeislife99

ymwangg, Aug 25 '22 01:08

I think this op may not need to be implemented with an XLA custom call. I found that it can be lowered using broadcast + diff + reduce. You can check out the code here: https://github.com/ymwangg/xla/commit/2191bef52d5a061febbf1d490ab5735cb92fae2f.
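
A minimal pure-Python sketch of the broadcast + diff + reduce idea, to show the shape logic. It assumes unbatched inputs `x1` of shape [P, M] and `x2` of shape [R, M] and a finite p > 0. The two loops stand in for broadcasting `x1` to [P, 1, M] and `x2` to [1, R, M]; the elementwise difference is raised to |d|^p, reduced over M, and the 1/p root is taken. The function name `cdist_broadcast` is hypothetical, and this is not the code in the linked commit.

```python
from typing import List

Matrix = List[List[float]]


def cdist_broadcast(x1: Matrix, x2: Matrix, p: float) -> Matrix:
    """Pairwise p-norm distances between rows of x1 [P, M] and x2 [R, M].

    Mirrors broadcast + diff + reduce; assumes non-empty inputs and finite p > 0.
    """
    if not x1 or not x2:
        raise ValueError("inputs must be non-empty")
    if p <= 0:
        raise ValueError("this sketch only handles finite p > 0")
    m = len(x1[0])
    if any(len(row) != m for row in x1 + x2):
        raise ValueError("all rows must have the same length M")
    out: Matrix = []
    for a in x1:  # plays the role of broadcasting x1 to [P, 1, M]
        row = []
        for b in x2:  # plays the role of broadcasting x2 to [1, R, M]
            diff = [ai - bi for ai, bi in zip(a, b)]  # elementwise diff
            total = sum(abs(d) ** p for d in diff)    # reduce over M
            row.append(total ** (1.0 / p))            # p-th root
        out.append(row)
    return out
```

For example, `cdist_broadcast([[0.0, 0.0], [1.0, 1.0]], [[3.0, 4.0]], 2.0)` gives a [2, 1] result whose first entry is 5.0. In a real lowering the intermediate [P, R, M] difference tensor is materialized, which is the memory cost the dot-based variant avoids.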

There are some issues that are not handled in this prototype and need more work:

  1. How to efficiently handle p = 1 and p = 2. In these cases xla::pow is not needed. XLA may optimize this through constant folding, but we need to verify it.
  2. How to handle special cases such as p = 0 and p = inf.
  3. Add test cases.
  4. Maybe there are other algorithms that are more efficient (e.g. using xla::dot). cc @codeislife99
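
For point 4, one well-known dot-based alternative for p = 2 (not something this thread has settled on) uses the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, so all pairwise inner products come from a single x1 @ x2^T product instead of a [P, R, M] broadcast. The sketch below is pure Python; `matmul_transpose` and `cdist_euclid_via_dot` are hypothetical names, with `matmul_transpose` standing in for the step that would map onto one dot op.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul_transpose(a: Matrix, b: Matrix) -> Matrix:
    """Compute a @ b^T; this is the step that would map onto a single dot op."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def cdist_euclid_via_dot(x1: Matrix, x2: Matrix) -> Matrix:
    """p = 2 distances using ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b."""
    sq1 = [sum(v * v for v in row) for row in x1]  # [P] squared row norms
    sq2 = [sum(v * v for v in row) for row in x2]  # [R] squared row norms
    cross = matmul_transpose(x1, x2)               # [P, R] inner products
    out: Matrix = []
    for i, s1 in enumerate(sq1):
        row = []
        for j, s2 in enumerate(sq2):
            d2 = s1 + s2 - 2.0 * cross[i][j]
            # Cancellation can make d2 slightly negative; clamp before sqrt.
            row.append(math.sqrt(max(d2, 0.0)))
        out.append(row)
    return out
```

The trade-off: this avoids materializing the [P, R, M] difference tensor, but subtracting large squared norms can lose precision when two points are very close, which is why the clamp is needed.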

ymwangg, Sep 14 '22 00:09

I think we can use xla::norm directly. It supports p=0 and p=inf, so it should address all of your comments.
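
To make the special cases concrete, here is a pure-Python sketch of the per-pair reduction such a norm would need to implement. It does not reflect the actual signature or implementation of xla::norm, which is not shown in this thread. It assumes the usual conventions: p = 0 counts non-zero entries, p = inf takes the largest absolute entry, and p = 1 and p = 2 are written without a general pow. `vector_pnorm` and `cdist` are hypothetical names.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def vector_pnorm(v: Sequence[float], p: float) -> float:
    """p-norm of one difference vector, with the special cases from the thread."""
    if p < 0:
        raise ValueError("p must be non-negative")
    if p == 0:
        # "0-norm": number of non-zero entries (not a true norm).
        return float(sum(1 for x in v if x != 0))
    if math.isinf(p):
        # inf-norm: largest absolute entry.
        return max((abs(x) for x in v), default=0.0)
    if p == 1:
        # No pow needed: plain sum of absolute values.
        return sum(abs(x) for x in v)
    if p == 2:
        # No general pow needed: square, sum, sqrt.
        return math.sqrt(sum(x * x for x in v))
    return sum(abs(x) ** p for x in v) ** (1.0 / p)


def cdist(x1: Matrix, x2: Matrix, p: float = 2.0) -> Matrix:
    """Pairwise distances: elementwise diff per pair, reduced with vector_pnorm."""
    return [[vector_pnorm([ai - bi for ai, bi in zip(a, b)], p) for b in x2]
            for a in x1]
```

Dispatching on p at trace time like this means the p = 1 and p = 2 graphs never contain a pow op, instead of relying on constant folding to remove it.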

codeislife99, Sep 14 '22 01:09