
Currently unsupported features and caveats (aka the big TODO list)

Open • EiffL opened this issue 2 years ago • 1 comment

Implementing every feature is not worthwhile unless someone actually needs it, so here are the restrictions of the code in its current version. All of these restrictions can be lifted if you need them; don't hesitate to comment on this issue if there is something you would like to be able to do!

General

  • [x] Configuration mechanism to choose the communication API (CUDA-aware MPI by default) only works at initialization
  • [x] No interface to run autotuning to figure out the best communication strategy
  • [x] No interface with the JAX 0.4 Array API

Transpose operations

  • [ ] Only transpose operations that preserve the size of local slices are supported
  • [ ] Transpose ops do not have batching or gradient operations implemented
  • [x] No support for letting XLA allocate the workspace size needed for the transpose
  • [x] Double precision is silently not supported (returns crazy values!)
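To make the first transpose restriction concrete, here is a minimal plain-Python sketch, assuming a generic 2D pencil decomposition in which X-pencil rank (i, j) holds y-block i and z-block j, and the matching Y-pencil rank holds x-block i and z-block j. The helpers split_sizes and transpose_x_to_y_preserves_local_size are hypothetical, and the "remainder goes to the first ranks" splitting rule is an assumption, not necessarily the layout cuDecomp uses.

```python
def split_sizes(n, p):
    """Block sizes when an axis of length n is split into p contiguous blocks.

    Hypothetical rule: the first n % p blocks get one extra element.
    """
    base, rem = divmod(n, p)
    return [base + (1 if r < rem else 0) for r in range(p)]


def transpose_x_to_y_preserves_local_size(shape, pdims):
    """Check whether every rank keeps the same number of elements when
    going from an X-pencil to a Y-pencil layout (hypothetical layout)."""
    nx, ny, nz = shape
    py, pz = pdims
    xs = split_sizes(nx, py)  # x-blocks in the Y-pencil layout
    ys = split_sizes(ny, py)  # y-blocks in the X-pencil layout
    zs = split_sizes(nz, pz)  # z-blocks, identical in both layouts
    for i in range(py):
        for j in range(pz):
            before = nx * ys[i] * zs[j]  # local X-pencil slice of rank (i, j)
            after = xs[i] * ny * zs[j]   # local Y-pencil slice of rank (i, j)
            if before != after:
                return False
    return True


print(transpose_x_to_y_preserves_local_size((8, 8, 8), (2, 2)))  # True
print(transpose_x_to_y_preserves_local_size((6, 4, 4), (2, 2)))  # True
print(transpose_x_to_y_preserves_local_size((9, 8, 8), (2, 2)))  # False
```

Under these assumptions, any shape whose axes divide evenly by the process grid passes the check (even a non-cubic one), while a mismatch such as nx = 9 over two ranks changes the local slice size during the transpose.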

FFTs

  • [ ] Only complex FFTs are implemented as a single CUDA-level operation
  • [ ] Only FFT operations that preserve the size of local slices are supported
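The two FFT restrictions are linked: a real-to-complex FFT shortens the transformed axis from n to n // 2 + 1 (the convention of numpy.fft.rfftn and jax.numpy.fft.rfftn), so it changes the size of local slices, and the result usually no longer divides evenly across the process grid. A minimal pure-Python sketch of that shape arithmetic follows; r2c_shape and divides_evenly are hypothetical helpers, and splitting the last axis over 4 ranks is an illustrative assumption.

```python
def r2c_shape(shape):
    """Global shape after a real-to-complex FFT over the last axis
    (n -> n // 2 + 1, as in numpy.fft.rfftn / jax.numpy.fft.rfftn)."""
    *lead, n = shape
    return (*lead, n // 2 + 1)


def divides_evenly(shape, pdims_per_axis):
    """True if every axis splits into equal local slices over its ranks."""
    return all(n % p == 0 for n, p in zip(shape, pdims_per_axis))


real_shape = (256, 256, 256)
grid = (1, 1, 4)  # hypothetical: last axis split over 4 ranks
complex_shape = r2c_shape(real_shape)

print(complex_shape)                        # (256, 256, 129)
print(divides_evenly(real_shape, grid))     # True
print(divides_evenly(complex_shape, grid))  # False
```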

EiffL, Nov 24 '22 18:11

Adding to the todo list

  • [x] Update CMake for CUDA 12; starting with JAX 0.4.26, CUDA 11 support is dropped
  • [x] Clean up inner and outer primitives and use NVIDIA TE class

ASKabalan, Mar 29 '24 19:03

One last thing

  • [ ] Support rfft, irfft and rfftfreq
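For reference, these are the single-device semantics a distributed rfftfreq would have to reproduce: numpy.fft.rfftfreq(n, d) (and jax.numpy.fft.rfftfreq) returns the n // 2 + 1 frequencies k / (n * d) for k = 0, ..., n // 2. The sketch below re-implements that in plain Python and adds a hypothetical local_rfftfreq that hands each rank a contiguous block of the frequency axis; the splitting rule is an assumption, not jaxDecomp's actual layout.

```python
def rfftfreq(n, d=1.0):
    """Sample frequencies of a length-n real FFT: k / (n * d) for k = 0..n // 2.
    Same convention as numpy.fft.rfftfreq."""
    val = 1.0 / (n * d)
    return [k * val for k in range(n // 2 + 1)]


def local_rfftfreq(n, rank, nranks, d=1.0):
    """Hypothetical: the contiguous block of rfftfreq(n, d) owned by `rank`
    when the frequency axis is split over `nranks` ranks
    (the first len % nranks ranks get one extra entry)."""
    freqs = rfftfreq(n, d)
    base, rem = divmod(len(freqs), nranks)
    start = rank * base + min(rank, rem)
    stop = start + base + (1 if rank < rem else 0)
    return freqs[start:stop]


print(rfftfreq(8))              # [0.0, 0.125, 0.25, 0.375, 0.5]
print(local_rfftfreq(8, 0, 2))  # [0.0, 0.125, 0.25]
print(local_rfftfreq(8, 1, 2))  # [0.375, 0.5]
```

Note that n // 2 + 1 = 5 entries cannot be shared equally between 2 ranks, the same uneven-slice issue that keeps real FFTs outside the current "preserve local slice size" restriction.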

ASKabalan, Oct 30 '24 15:10