
Update XLA torch.gather heuristics for multi-platforms

Open yeounoh opened this issue 3 years ago • 2 comments

🚀 Feature

There are two code paths for composing and executing torch.gather in PyTorch/XLA, and a custom heuristic decides which one to use. That heuristic needs a major update to better serve various platforms and use cases. cc @ymwangg @ronghanghu
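To make the trade-off concrete, here is a plain-Python sketch of two ways torch.gather(input, 1, index) can be composed for 2-D inputs: direct indexing, and a one-hot compare-and-sum that avoids indexed loads. We are assuming the two PyTorch/XLA code paths are broadly of these two shapes; the function names are hypothetical and this is not the actual lowering code.

```python
from typing import List

Matrix = List[List[float]]
IndexMatrix = List[List[int]]


def gather_sparse(inp: Matrix, index: IndexMatrix) -> Matrix:
    """Hypothetical helper: torch.gather(inp, 1, index) via direct indexing.

    out[i][j] = inp[i][index[i][j]]; work scales with the number of index elements.
    """
    return [[inp[i][k] for k in row] for i, row in enumerate(index)]


def gather_dense(inp: Matrix, index: IndexMatrix) -> Matrix:
    """Hypothetical helper: same result via a one-hot mask and a sum reduction.

    For each output element, every position k along dim 1 is compared with the
    requested index and the masked values are summed, so work scales with
    (index elements) x (input size along dim 1).
    """
    out = []
    for i, row in enumerate(index):
        out_row = []
        for target in row:
            # One-hot mask over dim 1: 1.0 where k == target, else 0.0.
            out_row.append(sum(x * (1.0 if k == target else 0.0)
                               for k, x in enumerate(inp[i])))
        out.append(out_row)
    return out


inp = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
idx = [[2, 0], [1, 1]]
print(gather_sparse(inp, idx))  # [[3.0, 1.0], [5.0, 5.0]]
print(gather_dense(inp, idx))   # [[3.0, 1.0], [5.0, 5.0]]
```

Both produce the same values for finite inputs, but the dense form does more arithmetic as the gathered dimension grows, which is why a threshold on the number of input elements is a natural knob for choosing between them.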

Motivation

Based on a recent study on GPU/CPU, and on related performance regressions we have seen on TPU, we feel strongly that we need a better heuristic that serves different platforms.

Pitch

Two different strategies have been tried so far to address the issue:

  • always forcing one torch.gather implementation;
  • deciding based on the number of input tensor elements (relative to the indices), and letting users try a different threshold by setting an env var.

Both approaches had clear shortcomings and behaved differently across platforms. We propose improving the heuristic in the following ways:

  • Pre-condition on the platform type.
  • Update the heuristic with a more conservative default threshold (advanced users can still try their own by setting an env var).
  • Use the threshold-based heuristic only for the TPU platform; for others
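The proposal can be sketched as a small decision function: check the platform first, apply the threshold only on TPU, and let an env var override the default. This is a minimal sketch under stated assumptions: the env var name `EXAMPLE_GATHER_DENSE_THRESHOLD`, the default value, and the function names are all hypothetical (not real PyTorch/XLA flags or APIs); it assumes larger inputs favour the indexing path; and because the issue does not say what other platforms should do, that choice is passed in by the caller.

```python
import os

# Hypothetical env var name and default, for illustration only;
# not an actual PyTorch/XLA flag or value.
THRESHOLD_ENV_VAR = "EXAMPLE_GATHER_DENSE_THRESHOLD"
DEFAULT_THRESHOLD = 8192  # hypothetical "conservative" default


def gather_threshold() -> int:
    """Return the default threshold unless the env var overrides it."""
    raw = os.environ.get(THRESHOLD_ENV_VAR)
    return int(raw) if raw is not None else DEFAULT_THRESHOLD


def use_sparse_gather(platform: str, input_elements: int,
                      non_tpu_choice: bool) -> bool:
    """Hypothetical decision between the indexing and mask-and-reduce paths.

    1. Pre-condition on the platform type.
    2. On TPU only, compare the input-element count against the threshold
       (assumption: above the threshold, the indexing path wins).
    3. Other platforms get a fixed choice supplied by the caller, since the
       issue leaves that behaviour unspecified.
    """
    if platform.upper() != "TPU":
        return non_tpu_choice
    return input_elements > gather_threshold()
```

Keeping the threshold behind a single function means an advanced user can experiment by exporting the env var, while the default stays conservative for everyone else.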

Alternatives

Either of the two approaches we have already tried.

yeounoh • May 19 '22

Hi @ymwangg, I will follow up with you once I create a PR to address this. Stay tuned!

yeounoh • May 19 '22

@yeounoh Can this be closed?

JackCaoG • Jun 30 '22

This was fixed with https://github.com/pytorch/xla/pull/3629, closing.

yeounoh • Aug 31 '22