[DETR support] Lower aten::_cdist_forward
This op is needed to support the DETR model. We'll use this thread to track progress.
- Owner: @codeislife99
I don't think this op needs an XLA custom call. I found that it can be lowered using broadcast + diff + reduce. You can check out the code here: https://github.com/ymwangg/xla/commit/2191bef52d5a061febbf1d490ab5735cb92fae2f.
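For readers who haven't opened the commit, here is a minimal reference sketch of the broadcast + diff + reduce decomposition. It is plain Python rather than the XLA C++ builder API, and it is not the code from the linked commit. It covers the unbatched case with `x1` of shape [P, M] and `x2` of shape [R, M], and assumes a finite p > 0. `cdist_broadcast` is a hypothetical name. Conceptually, `x1` is broadcast to [P, 1, M] and `x2` to [1, R, M], the two are subtracted, and `|d|^p` is summed over the last dimension before taking the 1/p root.

```python
def cdist_broadcast(x1, x2, p):
    """Reference p-norm distances between rows of x1 ([P][M]) and x2 ([R][M]).

    Mirrors the lowering: broadcast to [P, R, M], subtract, reduce over M.
    Assumes finite p > 0; p = 0 and p = inf need different reductions.
    """
    # Broadcast + diff: diff[i][j][k] = x1[i][k] - x2[j][k], shape [P, R, M].
    diff = [[[a - b for a, b in zip(u, v)] for v in x2] for u in x1]
    # Reduce over M: (sum_k |d_k|^p) ** (1/p), giving a [P, R] result.
    return [
        [sum(abs(d) ** p for d in vec) ** (1.0 / p) for vec in row]
        for row in diff
    ]


x1 = [[0.0, 0.0], [1.0, 1.0]]
x2 = [[3.0, 4.0]]
print(cdist_broadcast(x1, x2, 1.0))  # [[7.0], [5.0]]
```

The intermediate [P, R, M] tensor is what makes this simple, but it also makes memory grow with P * R * M. That cost is why the alternatives discussed below are worth considering.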
There are some issues that this prototype does not handle yet and that need more work:
- How to efficiently handle p = 1 and p = 2. In these cases xla::pow is not needed. XLA may optimize this away through constant folding, but we need to verify it.
- How to handle special cases such as p = 0 and p = inf.
- Add test cases.
- There may be other, more efficient algorithms (e.g. using xla::dot). cc @codeislife99
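For p = 2, one dot-based option is the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y. This option is not part of the prototype. With it, the cross term becomes a single matrix multiply, which is the piece that could map to xla::dot, and the [P, R, M] difference tensor is never materialized. PyTorch's own `torch.cdist` offers a similar matrix-multiplication path for p = 2 through its `compute_mode` argument. The sketch below uses plain Python loops as a stand-in for the matmul, and `cdist_p2_via_dot` is a hypothetical name. Rounding can push the expanded value slightly below zero, so it is clamped before the square root.

```python
import math


def cdist_p2_via_dot(x1, x2):
    """Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y (rows of x1 vs rows of x2)."""
    # Squared row norms of each input.
    n1 = [sum(a * a for a in u) for u in x1]
    n2 = [sum(b * b for b in v) for v in x2]
    # Cross term x1 @ x2^T: the part a single dot/matmul would compute.
    cross = [[sum(a * b for a, b in zip(u, v)) for v in x2] for u in x1]
    # Combine, clamping tiny negative values caused by floating-point cancellation.
    return [
        [math.sqrt(max(n1[i] + n2[j] - 2.0 * cross[i][j], 0.0)) for j in range(len(x2))]
        for i in range(len(x1))
    ]


print(cdist_p2_via_dot([[0.0, 0.0], [1.0, 1.0]], [[3.0, 4.0]]))
```

The trade-off is accuracy. When two points are close, the subtraction cancels large terms, so the result can be less precise than the direct broadcast + diff version.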
I think we can use xla::norm directly. It supports p = 0 and p = inf, so it should address all of your comments.
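When checking a norm-based lowering against PyTorch, it helps to spell out which reduction each p maps to. The conventions below follow `torch.linalg.vector_norm`: p = 0 counts the nonzero entries, and p = inf takes the largest absolute value. This is a plain-Python sketch applied to one difference vector at a time. `pnorm` and `cdist` are hypothetical helpers, and the sketch makes no claim about how xla::norm is implemented internally.

```python
import math


def pnorm(diff, p):
    """p-norm of one difference vector x1[i] - x2[j], for p in [0, inf]."""
    if p == 0:
        # "Zero norm": number of coordinates that differ.
        return float(sum(1 for d in diff if d != 0))
    if math.isinf(p):
        # Chebyshev distance: largest absolute difference (0.0 for an empty vector).
        return max((abs(d) for d in diff), default=0.0)
    if p == 1:
        # Manhattan distance: no pow needed.
        return sum(abs(d) for d in diff)
    if p == 2:
        # Euclidean distance: square + sqrt instead of a general pow.
        return math.sqrt(sum(d * d for d in diff))
    # General finite p > 0.
    return sum(abs(d) ** p for d in diff) ** (1.0 / p)


def cdist(x1, x2, p=2.0):
    """Pairwise distances between rows of x1 and rows of x2."""
    return [[pnorm([a - b for a, b in zip(u, v)], p) for v in x2] for u in x1]


x1 = [[0.0, 0.0], [1.0, 4.0]]
x2 = [[3.0, 4.0]]
for p in (0, 1, 2, 3, math.inf):
    print(p, cdist(x1, x2, p))
```

Because p is a plain float argument of `_cdist_forward`, a lowering can branch on it when the graph is built. That branching is how p = 1 and p = 2 avoid a pow entirely.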