[NVIDIA] Optimize deterministic scalar scatter

Open • serach24 opened this issue 1 year ago • 1 comment

This PR is the first of two steps to improve the performance of deterministic scatter. Previously, a scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc. However, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves all other scatter operations to be handled by the original ScatterExpander. The second PR will handle non-scalar indices and updates.
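To illustrate the general idea, here is a minimal pure-Python sketch of a sort plus segmented-prefix-scan scatter for scalar indices and updates. It is not the PR's implementation (which rewrites HLO inside XLA). The names `serial_scatter`, `scan_scatter`, and `combine` are hypothetical. The sketch assumes all indices are in bounds and that `combine` is associative. `serial_scatter` roughly mimics a loop that applies one update per iteration, and is included for comparison.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")

# Hypothetical helpers: a sketch of the technique, not the XLA implementation.
# Assumes every index is in bounds and `combine` is associative.


def serial_scatter(operand: Sequence[T], indices: Sequence[int],
                   updates: Sequence[T], combine: Callable[[T, T], T]) -> List[T]:
    """Reference: apply updates one at a time, in order (loop-style)."""
    result = list(operand)
    for i, u in zip(indices, updates):
        result[i] = combine(result[i], u)
    return result


def scan_scatter(operand: Sequence[T], indices: Sequence[int],
                 updates: Sequence[T], combine: Callable[[T, T], T]) -> List[T]:
    """Deterministic scalar scatter via stable sort + segmented prefix scan."""
    n = len(indices)
    # 1. Stable sort by index: equal indices become adjacent and keep their
    #    original relative order, so the combine order is fixed.
    order = sorted(range(n), key=lambda k: indices[k])
    idx = [indices[k] for k in order]
    scanned = [updates[k] for k in order]

    # 2. Segmented inclusive prefix scan (Hillis-Steele style). After the step
    #    with stride d, scanned[i] combines up to 2*d consecutive updates that
    #    end at i, never crossing into a run with a different index. This takes
    #    ceil(log2(n)) steps, and each step's inner loop is data-parallel.
    stride = 1
    while stride < n:
        prev = scanned  # each step reads only the previous step's values
        scanned = list(prev)
        for i in range(stride, n):
            if idx[i - stride] == idx[i]:
                scanned[i] = combine(prev[i - stride], prev[i])
        stride *= 2

    # 3. The last element of each run holds that run's total. Each target
    #    index is written exactly once, so no two writes conflict.
    result = list(operand)
    for i in range(n):
        if i == n - 1 or idx[i + 1] != idx[i]:
            result[idx[i]] = combine(result[idx[i]], scanned[i])
    return result


if __name__ == "__main__":
    add = lambda a, b: a + b
    print(scan_scatter([0, 0, 0, 0], [2, 0, 2, 3, 2], [1, 2, 3, 4, 5], add))
```

Running the script prints `[2, 0, 9, 4]`: index 2 receives 1 + 3 + 5, index 0 receives 2, and index 3 receives 4. Determinism in this sketch comes from the fixed combine tree (stable sort order plus a fixed scan pattern), not from `combine` being commutative. For floating-point addition the result is reproducible run to run, but it may round differently from the strictly sequential order used by `serial_scatter`.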

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: https://github.com/google/jax/issues/17844

serach24 • Oct 03 '24 03:10

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] • Oct 03 '24 03:10

  1. Is it possible to squash this into one commit with a much more detailed commit message?

I think it is doable, but won't PRs be squashed on merge anyway?

  2. Could you provide microbenchmark results, especially comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even make it deterministic by default?

This is provided in the evaluation section of the attached doc.

serach24 • Oct 04 '24 16:10

This is provided in the evaluation section of the attached doc

Could we duplicate some of it in the commit message? Google Docs tend not to live very long: access can be revoked at any time, whereas the commit message is there forever.

cheshire • Oct 15 '24 11:10

I squashed the commits into one and included the microbenchmark results in the commit message. Please take another look.

serach24 • Oct 15 '24 17:10

The linter complains:

CheckLint found errors.
These lines are out of order.
	xla/service/gpu/BUILD:1420-1656 (http://google3/third_party/tensorflow/compiler/xla/service/gpu/BUILD:1420-1656)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_utils.h:61 (http://google3/third_party/tensorflow/compiler/xla/service/scatter_utils.h?l=61)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_determinism_expander.h:44 (http://google3/third_party/tensorflow/compiler/xla/service/scatter_determinism_expander.h?l=44)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_utils.cc:212 (http://google3/third_party/tensorflow/compiler/xla/service/scatter_utils.cc?l=212)

ezhulenev • Oct 16 '24 09:10
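One way to address the `whitespace/ending_newline` findings is a small helper that appends a final newline to any non-empty file that lacks one. This is a minimal sketch: the function name `ensure_trailing_newline` is hypothetical and not part of any XLA tooling. It does not fix the BUILD ordering finding.

```python
from pathlib import Path


def ensure_trailing_newline(path: Path) -> bool:
    """Append b"\\n" to a non-empty file that lacks one; return True if changed.

    Hypothetical helper for the whitespace/ending_newline lint finding.
    """
    data = path.read_bytes()
    if data and not data.endswith(b"\n"):
        path.write_bytes(data + b"\n")
        return True
    return False


if __name__ == "__main__":
    # Paths taken from the lint output above.
    for name in ("xla/service/scatter_utils.h",
                 "xla/service/scatter_determinism_expander.h",
                 "xla/service/scatter_utils.cc"):
        p = Path(name)
        if p.exists() and ensure_trailing_newline(p):
            print(f"fixed {name}")
```

Working on raw bytes leaves the file's encoding and existing line endings untouched. Empty files are left as-is.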