jax
[DRAFT] Custom GPU ops: custom partitioning
This in-progress PR modifies the `docs/Custom_Operation_for_GPUs.py` tutorial to use `custom_partitioning` instead of `xmap`.
Don't review yet; there is still a lot of work to do.
- [x] finish the forward code.
- [x] finish the backward code.
- [ ] update the documentation to explain it.
- [ ] document debugging tricks, like looking at the dump after sharding propagation but before partitioning (`module_0019.pjit__unnamed_function_.0011.spmd-partitioner.after_sharding-propagation.before_spmd-partitioning.txt`), produced with `XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"`.
- [ ] commit the full code so it is easy to start from it.
- [ ] add a test for the committed full code so it does not break again, as has already happened.
- [ ] modify the documentation to inline parts of the committed code instead of duplicating them.
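The XLA dump trick from the checklist can be scripted roughly as follows. This is a minimal sketch, not code from the tutorial: the helper names `enable_hlo_dump` and `find_pre_partitioning_dumps` are hypothetical, and it assumes the dump files keep the `after_sharding-propagation.before_spmd-partitioning.txt` suffix seen in the file name quoted in the checklist.

```python
import glob
import os

# Suffix copied from the dump file name in the checklist: the HLO module
# after sharding propagation, but before SPMD partitioning.
TARGET_SUFFIX = "after_sharding-propagation.before_spmd-partitioning.txt"


def enable_hlo_dump(dump_dir):
    """Append XLA flags that dump the HLO after every pass into dump_dir.

    Hypothetical helper. Call it before `import jax`: XLA reads XLA_FLAGS
    when the backend is initialized, so later changes are ignored.
    """
    flags = f"--xla_dump_to={dump_dir} --xla_dump_hlo_pass_re=.*"
    existing = os.environ.get("XLA_FLAGS", "")
    # Keep any flags the user already set, then add the dump flags.
    os.environ["XLA_FLAGS"] = f"{existing} {flags}".strip()


def find_pre_partitioning_dumps(dump_dir):
    """Return, sorted, the dump files written just before SPMD partitioning."""
    pattern = os.path.join(dump_dir, f"*.{TARGET_SUFFIX}")
    return sorted(glob.glob(pattern))
```

For example, call `enable_hlo_dump("/tmp/xla_dump")` (hypothetical directory) before `import jax`, run the sharded function, then open the files returned by `find_pre_partitioning_dumps("/tmp/xla_dump")` to see the shardings chosen by propagation before the SPMD partitioner rewrites the module.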
Replaced by https://github.com/google/jax/pull/20179