Update XLA torch.gather heuristics for multi-platforms
🚀 Feature
There are two code paths for composing and executing torch.gather in PyTorch/XLA, and a custom heuristic decides between them. The heuristic needs a major update to better serve different platforms and use cases. cc @ymwangg @ronghanghu
Motivation
Based on a recent study on GPU/CPU, and after hitting a related performance regression on TPU, we feel strongly that we need a better heuristic to serve the different platforms.
Pitch
There were a couple of different strategies employed to address the issue:
- always forcing one type of torch.gather implementation,
- making the decision based on the number of input tensor elements (relative to the indices), and allowing the user to try a different threshold by setting an env var.
Both approaches had clear shortcomings and diverged in behavior across platforms. We propose to improve the heuristic in the following ways:
- Pre-condition on the platform type
- Update the heuristic with a more conservative default threshold (advanced users can still try their own by setting an env var)
- Use the threshold-based heuristic only on the TPU platform; other platforms fall back to a fixed implementation
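The proposed decision logic could be sketched roughly as below. This is only an illustration, not the actual PyTorch/XLA implementation: the function name, the env var name, and the default factor are all assumptions, and the real lowering is written in C++.

```python
import os

# Assumed conservative default; the actual value used by PyTorch/XLA may differ.
DEFAULT_DENSE_GATHER_FACTOR = 8192


def use_dense_gather(platform: str, input_elements: int, index_elements: int) -> bool:
    """Hypothetical sketch: pick between the two torch.gather lowerings.

    Only TPU uses the threshold-based heuristic; other platforms always
    take one fixed code path (per the proposal above).
    """
    if platform != "TPU":
        # Non-TPU platforms: no threshold heuristic, use one implementation.
        return False
    # Advanced users can override the threshold via an env var
    # (the variable name here is an assumption for illustration).
    factor = int(os.environ.get("XLA_DENSE_GATHER_FACTOR",
                                DEFAULT_DENSE_GATHER_FACTOR))
    # Dense gather only when the input is small relative to the indices.
    return input_elements <= factor * index_elements
```

The key design point is the platform pre-condition: the element-count threshold is consulted only on TPU, so GPU/CPU behavior no longer depends on a tuning knob that was calibrated for a different backend.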
Alternatives
One of the two approaches that we have already tried.
Hi @ymwangg I will follow up with you once I create a PR to address this -- stay tuned!
@yeounoh Can this be closed?
This was fixed with https://github.com/pytorch/xla/pull/3629, closing.