jaxDecomp
Currently unsupported features and caveats (aka the big TODO list)
Not every feature is worth implementing until someone needs it, so here are the restrictions of the code in its current version. Any of them can be lifted on request: don't hesitate to comment on this issue if there is something you would like to be able to do!
General
- [x] Configuration mechanism to choose the communication API (by default CUDA-aware MPI) only works at initialization
- [x] No interface to run autotuning to figure out the best communication strategy
- [x] No interface with the JAX 0.4 Array API
Transpose operations
- [ ] Only transpose operations that preserve the size of local slices are supported
- [ ] Transpose ops do not have batching or gradient operations implemented
- [x] No support for letting XLA allocate the workspace size needed for the transpose
- [x] Double precision is not supported and fails silently (returns garbage values!)
FFTs
- [ ] Only complex FFTs are implemented as a single CUDA-level operation
- [ ] Only FFT operations that preserve the size of local slices are supported
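To make the "preserve the size of local slices" restriction concrete, here is a minimal pure-Python sketch. It assumes one common 2D pencil convention (x-pencil `(nx, ny/px, nz/py)`, y-pencil `(nx/px, ny, nz/py)`, z-pencil `(nx/px, ny/py, nz)`); the actual layout and axis order inside jaxDecomp may differ, and `pencil_shapes` is a hypothetical helper, not part of the library.

```python
def pencil_shapes(global_shape, pdims):
    """Local slice shapes for x-, y- and z-pencils on a 2D process grid.

    Assumed convention (not taken from jaxDecomp itself):
      x-pencil: (nx,      ny / px, nz / py)
      y-pencil: (nx / px, ny,      nz / py)
      z-pencil: (nx / px, ny / py, nz)
    Raises ValueError if any distributed axis does not split evenly,
    since the local slice size would then differ between devices.
    """
    nx, ny, nz = global_shape
    px, py = pdims
    # Every (axis length, process count) pair that is split in some pencil.
    for n, p in ((nx, px), (ny, px), (ny, py), (nz, py)):
        if n % p != 0:
            raise ValueError(f"axis of length {n} does not split evenly over {p} processes")
    return {
        "x": (nx, ny // px, nz // py),
        "y": (nx // px, ny, nz // py),
        "z": (nx // px, ny // py, nz),
    }


shapes = pencil_shapes((8, 8, 8), (2, 4))
print(shapes)  # {'x': (8, 4, 2), 'y': (4, 8, 2), 'z': (4, 2, 8)}
```

With an even split, every pencil holds the same number of elements per device (64 here), so a transpose never changes the local buffer size. A shape such as `(8, 6, 8)` on the same grid raises, because 6 does not divide by 4.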
Adding to the todo list
- [x] Update CMake for CUDA 12; starting with JAX 0.4.26, CUDA 11 support is dropped
- [x] Clean up inner and outer primitives and use NVIDIA TE class
One last thing
- [ ] Support `rfft`, `irfft` and `rfftfreq`
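For reference, a small pure-Python sketch of the shapes a real FFT produces, following the `numpy.fft` convention where `n` real samples give `n // 2 + 1` complex bins. The helper names are hypothetical and only illustrate why real transforms interact with the even-split restriction listed above; this is not jaxDecomp code.

```python
def rfft_output_length(n):
    """Number of complex bins kept by a real FFT of n real samples."""
    return n // 2 + 1


def rfftfreq(n, d=1.0):
    """Frequencies of those bins, same convention as numpy.fft.rfftfreq."""
    return [k / (n * d) for k in range(rfft_output_length(n))]


def splits_evenly(length, n_devices):
    """True if an axis of this length gives equal-size slices per device."""
    return length % n_devices == 0


n, n_devices = 8, 4
print(rfft_output_length(n))  # 5
print(rfftfreq(n))            # [0.0, 0.125, 0.25, 0.375, 0.5]
# The real input axis splits evenly, the half-spectrum axis does not.
print(splits_evenly(n, n_devices), splits_evenly(rfft_output_length(n), n_devices))  # True False
```

The odd-looking `n // 2 + 1` length is the reason a real transform cannot simply reuse a path that assumes equal local slices along the transformed axis.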