xla
[NVIDIA] Optimize deterministic scalar scatter
This PR is the first of two steps to improve the performance of deterministic scatter. Previously, when determinism was required, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc. Because that expansion uses a while loop, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves all other scatter operations to the original ScatterExpander. The second PR will handle non-scalar indices and updates.
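The general idea behind a sort-plus-prefix-scan deterministic scatter can be sketched in plain Python. This is a simplified, hypothetical illustration, not the code in this PR. It assumes a scalar scatter with an addition combiner, a 1-D operand, and in-bounds indices. The sketch stably sorts (index, update) pairs by index, runs a segmented inclusive prefix scan over the sorted updates, and keeps only the last element of each run of equal indices. A final scatter then writes each index exactly once, so the result no longer depends on the order in which conflicting updates happen to be applied. The helper name `deterministic_scatter_add` is made up for this example.

```python
from typing import List


def deterministic_scatter_add(
    operand: List[float], indices: List[int], updates: List[float]
) -> List[float]:
    """Hypothetical sketch of a sort + segmented-prefix-scan scatter-add."""
    # Stable sort of update positions by index: ties keep their original order,
    # so the accumulation order for each index is fixed and reproducible.
    order = sorted(range(len(indices)), key=lambda i: indices[i])
    sorted_idx = [indices[i] for i in order]
    sorted_upd = [updates[i] for i in order]

    # Segmented inclusive prefix scan: restart the running sum whenever the
    # index changes. (Sequential here; a GPU version would use a parallel scan.)
    scanned: List[float] = []
    for k, (idx, upd) in enumerate(zip(sorted_idx, sorted_upd)):
        if k > 0 and sorted_idx[k - 1] == idx:
            scanned.append(scanned[-1] + upd)
        else:
            scanned.append(upd)

    # Only the last element of each segment holds the full sum for its index.
    result = list(operand)
    for k, idx in enumerate(sorted_idx):
        is_last = k + 1 == len(sorted_idx) or sorted_idx[k + 1] != idx
        if is_last:
            # Indices are unique at this point, so there are no write conflicts.
            result[idx] += scanned[k]
    return result


print(deterministic_scatter_add([0.0, 0.0, 0.0, 0.0], [2, 0, 2, 1, 2],
                                [1.0, 2.0, 3.0, 4.0, 5.0]))
# [2.0, 4.0, 9.0, 0.0]
```

The sequential loops only show the data flow. The point of the design is that the sort and the scan both parallelize well, whereas a while loop that applies one update per iteration does not.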
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: https://github.com/google/jax/issues/17844
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
- Is it possible to squash this into one commit with a much more detailed commit message?
I think it is doable, but won't PRs be squashed on merge anyway?
- Could you provide microbenchmark results, esp. comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even try to make it deterministic by default?
This is provided in the evaluation section of the attached doc.
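Deterministic and non-deterministic scatter can produce different results at all because floating-point addition is not associative. When several updates target the same index, the final value depends on the order in which they are summed. On GPUs, conflicting scatter-add updates are typically applied with atomics in an unspecified order. The small Python example below, which is illustrative only and not taken from the PR, sums the same three updates in every possible order and collects the distinct results.

```python
import itertools

# Three updates that all target the same scatter index.
updates = [0.1, 0.2, 0.3]


def accumulate(order):
    """Sum the updates in the given order, like a sequence of atomic adds."""
    total = 0.0
    for i in order:
        total += updates[i]
    return total


# Every permutation corresponds to one possible order of conflicting writes.
distinct_sums = {accumulate(p) for p in itertools.permutations(range(len(updates)))}
print(len(distinct_sums), sorted(distinct_sums))
```

More than one distinct sum appears, which is why making scatter deterministic requires fixing the accumulation order rather than just the set of updates.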
Could we duplicate some of it in the commit message? Google Docs don't tend to live for very long: the access can be pulled at any time, whereas the commit message is there forever.
I squashed the commits into one and included the microbenchmark results in the commit message. Please take another look.
Linter complains with:
CheckLint found errors.
These lines are out of order.
xla/service/gpu/BUILD:1420-1656
Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
xla/service/scatter_utils.h:61
Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
xla/service/scatter_determinism_expander.h:44
Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
xla/service/scatter_utils.cc:212
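The `whitespace/ending_newline` errors only require each flagged file to end with a newline character. A minimal, hypothetical Python helper that appends the missing newline is sketched below. The function name `ensure_trailing_newline` is made up for this example. The file list is copied from the lint output above, and files that do not exist in the current working directory are skipped.

```python
from pathlib import Path


def ensure_trailing_newline(path: Path) -> bool:
    """Append a final newline if the file is non-empty and lacks one.

    Returns True if the file was modified.
    """
    data = path.read_bytes()
    if data and not data.endswith(b"\n"):
        path.write_bytes(data + b"\n")
        return True
    return False


# Files flagged by the linter; run from the repository root.
for name in [
    "xla/service/scatter_utils.h",
    "xla/service/scatter_determinism_expander.h",
    "xla/service/scatter_utils.cc",
]:
    p = Path(name)
    if p.exists():
        ensure_trailing_newline(p)
```

The "lines are out of order" error in the BUILD file is a separate issue: the targets in that range need to be reordered.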