[DRAFT] Custom gpu ops custom partitioning

Open nouiz opened this issue 2 years ago • 1 comments

This in-progress PR modifies the docs/Custom_Operation_for_GPUs.py tutorial to use custom_partitioning instead of xmap.

Don't review it yet; there is still a lot of work to do.

  • [x] finish the forward code.
  • [x] finish the backward code.
  • [ ] update the documentation to explain it.
  • [ ] document debugging tricks, such as looking at the dump written after sharding propagation but before partitioning (for example module_0019.pjit__unnamed_function_.0011.spmd-partitioner.after_sharding-propagation.before_spmd-partitioning.txt), which is produced by setting XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"
  • [ ] commit the full code so it is easy to start from it.
  • [ ] add a test of the committed full code to prevent breaking it, as has already happened.
  • [ ] modify the documentation to inline parts of the committed code instead of duplicating it.
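
The debugging trick in the checklist could be sketched roughly as follows. This is only an illustration, not code from the PR: the helper names `xla_dump_flags` and `find_post_propagation_dumps` and the directory `/tmp/xla_dump` are hypothetical, and the file-name suffix is taken from the single example file name quoted above, so it may differ between XLA versions.

```python
import os
from pathlib import Path


def xla_dump_flags(dump_dir):
    """Build the XLA_FLAGS value quoted in the checklist for a dump directory."""
    return f"--xla_dump_to={dump_dir} --xla_dump_hlo_pass_re=.*"


def find_post_propagation_dumps(dump_dir):
    """List dumps written after sharding propagation, before SPMD partitioning.

    The suffix is copied from the example file name in the checklist; it is an
    assumption that other modules follow the same naming scheme.
    """
    suffix = "*after_sharding-propagation.before_spmd-partitioning.txt"
    return sorted(Path(dump_dir).glob(suffix))


# Hypothetical dump location. Set the variable before JAX is imported so the
# flags are in place when XLA starts up.
os.environ["XLA_FLAGS"] = xla_dump_flags("/tmp/xla_dump")

# After running the jitted function under test, inspect the matching files.
for path in find_post_propagation_dumps("/tmp/xla_dump"):
    print(path.name)
```

Note that this sketch overwrites any existing `XLA_FLAGS` value; if other flags are already set, append to the existing string instead.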

nouiz avatar Dec 01 '23 15:12 nouiz

Replaced by https://github.com/google/jax/pull/20179

nouiz avatar Mar 13 '24 00:03 nouiz

nouiz avatar Apr 01 '24 13:04 nouiz