jaxDecomp
JAX bindings for the NVIDIA cuDecomp library
This PR aims to fix the issues with 3D FFTs. Right now, it only changes the FFT test to check the result of the forward FFT, and this test...
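Checking the forward FFT amounts to comparing the distributed result with a single-device reference. A 3D FFT is separable: it equals three 1D FFTs applied along each axis in turn, and a pencil-based fft3d relies on exactly this property. The following is a minimal single-host sketch in NumPy. The helper `fft3d_by_axes` is hypothetical and is not part of jaxDecomp. The same functions exist under `jax.numpy.fft`.

```python
import numpy as np

def fft3d_by_axes(x):
    """Hypothetical helper: apply a 1D FFT along x, then y, then z.

    A pencil-decomposed 3D FFT uses this separability: each stage transforms
    one axis that is fully local, and transposes between stages make the
    next axis local.
    """
    y = np.fft.fft(x, axis=0)
    y = np.fft.fft(y, axis=1)
    y = np.fft.fft(y, axis=2)
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 8))  # a cube

reference = np.fft.fftn(x)  # non-distributed reference
staged = fft3d_by_axes(x)
print(np.allclose(staged, reference))  # True
```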
First pull request for transpose ops. They work for cubes for now (pencils and slabs).
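The transpose ops move data between pencil orientations, so it helps to know which local block shape each device holds before and after a transpose. This is a minimal sketch under stated assumptions. It assumes a cuDecomp-style convention on a 2D process grid `pdims = (p_rows, p_cols)` and requires each split axis to divide evenly by its grid dimension. The helper `pencil_local_shapes` is hypothetical and is not a jaxDecomp function.

```python
def pencil_local_shapes(global_shape, pdims):
    """Hypothetical helper: local block shape per device for X-, Y- and Z-pencils.

    Assumed convention: in an X-pencil, x is fully local, y is split over
    p_rows and z over p_cols; a Y-pencil splits x over p_rows and z over
    p_cols; a Z-pencil splits x over p_rows and y over p_cols. A slab is the
    special case where p_rows or p_cols equals 1.
    """
    nx, ny, nz = global_shape
    p_rows, p_cols = pdims
    # Every split axis must divide evenly by its grid dimension.
    for n, p in ((ny, p_rows), (nz, p_cols), (nx, p_rows), (ny, p_cols)):
        if n % p != 0:
            raise ValueError(f"axis of size {n} is not divisible by {p}")
    return {
        "x": (nx, ny // p_rows, nz // p_cols),
        "y": (nx // p_rows, ny, nz // p_cols),
        "z": (nx // p_rows, ny // p_cols, nz),
    }

print(pencil_local_shapes((64, 64, 64), (2, 4)))
# {'x': (64, 32, 16), 'y': (32, 64, 16), 'z': (32, 16, 64)}
print(pencil_local_shapes((64, 64, 64), (1, 8)))  # slab-like grid
# {'x': (64, 64, 8), 'y': (64, 64, 8), 'z': (64, 8, 64)}
```

Under this convention, an X-to-Y transpose exchanges data only within each process row, because z keeps the same split over p_cols.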
Fixed slice_unpad; it now supports complex numbers.
- [ ] slice_unpad does not seem to be happy with complex numbers as inputs:
  ```
  File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in
    recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
  jaxlib.xla_extension.XlaRuntimeError: INTERNAL:...
  ```
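slice_unpad removes the padding that was added to each device's slice. A single-host reference for that round trip can check that complex inputs come back unchanged. In this NumPy sketch, `pad_blocks` and `unpad_blocks` are hypothetical stand-ins, not jaxDecomp's slice_pad or slice_unpad. The sketch also assumes that padding is applied only on the two decomposed axes.

```python
import numpy as np

def pad_blocks(x, pdims, width):
    """Hypothetical stand-in for a per-slice pad: split x into a
    pdims[0] x pdims[1] grid of blocks over axes 0 and 1, zero-pad each block
    by `width` cells on both sides of those axes, and reassemble."""
    pad = ((width, width), (width, width), (0, 0))
    rows = []
    for row in np.split(x, pdims[0], axis=0):
        cols = [np.pad(b, pad) for b in np.split(row, pdims[1], axis=1)]
        rows.append(np.concatenate(cols, axis=1))
    return np.concatenate(rows, axis=0)

def unpad_blocks(x, pdims, width):
    """Inverse of pad_blocks: strip `width` cells from both sides of every
    padded block. Plain slicing keeps the dtype, complex included."""
    rows = []
    for row in np.split(x, pdims[0], axis=0):
        cols = [b[width:b.shape[0] - width, width:b.shape[1] - width]
                for b in np.split(row, pdims[1], axis=1)]
        rows.append(np.concatenate(cols, axis=1))
    return np.concatenate(rows, axis=0)

rng = np.random.default_rng(0)
x = (rng.standard_normal((8, 8, 4)) + 1j * rng.standard_normal((8, 8, 4))).astype(np.complex64)
padded = pad_blocks(x, pdims=(2, 2), width=1)
restored = unpad_blocks(padded, pdims=(2, 2), width=1)
print(padded.shape, padded.dtype)   # (12, 12, 4) complex64
print(np.array_equal(restored, x))  # True
```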
Currently, the CUDA version is set to 12.2 in the CMake configuration. jaxDecomp (and cuDecomp) can be compiled with CUDA 11.8, since there is no CUDA 12-specific code. JAX 0.4.26 and above no...
Comparing the 3D FFT computed by jaxDecomp with one computed manually in JAX, I realized that the result of fft3d does not match the non-distributed version. This could be due to...
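This sketch does not diagnose the cause of the mismatch. It illustrates one pitfall worth ruling out in such comparisons: output layout. Suppose a distributed fft3d returned its result with the axes stored in a transposed order. That is an assumption for illustration, not a statement about jaxDecomp's actual output. The values would then match the non-distributed `numpy.fft.fftn` only after the same axis permutation is applied. The sketch uses a cube, because there the shapes agree and a layout mismatch can go unnoticed.

```python
import numpy as np

rng = np.random.default_rng(1)
# A cube, so a transposed result has the same shape as the reference and a
# shape check alone would not catch a layout mismatch.
x = rng.standard_normal((8, 8, 8))

reference = np.fft.fftn(x)  # non-distributed reference

# Hypothetical: a result whose axes are stored in (y, z, x) order.
perm = (1, 2, 0)
permuted_result = np.fft.fftn(np.transpose(x, perm))

print(permuted_result.shape == reference.shape)  # True
print(np.allclose(permuted_result, reference))   # False for generic data
# A full 3D FFT commutes with permuting the axes, so applying the same
# permutation to the reference makes the two agree.
print(np.allclose(permuted_result, np.transpose(reference, perm)))  # True
```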
Because implementing every feature is not worthwhile unless there is a need for it, here are the restrictions of the code in its current version. **All these restrictions...
A change in the jaxDecomp internals to allow a single transposition for slab decompositions, keeping the interface the same.
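In a slab decomposition, two of the three axes are already local. A 3D FFT can therefore in principle use one transpose instead of two. The steps are an FFT over the two local axes, one all-to-all that makes the remaining axis local, and then a 1D FFT along that axis. The sketch below simulates this on a single host with NumPy, using a list of arrays to stand in for devices. Both the function name and the list-of-slabs model are hypothetical and do not describe how jaxDecomp implements it. One trade-off shows up in the result: the output stays split along y rather than x, which is the cost of skipping the second transpose.

```python
import numpy as np

def slab_fft3d_single_transpose(x, p):
    """Hypothetical single-host simulation of a slab-decomposed 3D FFT
    that needs only one global transpose.

    Input: x split into p slabs along axis 0, so each "device" holds
    (nx/p, ny, nz) with y and z fully local.
    Output: p slabs split along axis 1, each of shape (nx, ny/p, nz).
    """
    slabs = np.split(x, p, axis=0)
    # Step 1: local 2D FFT over the two axes every slab holds in full.
    slabs = [np.fft.fftn(s, axes=(1, 2)) for s in slabs]
    # Step 2: the single transpose (an all-to-all between devices): device j
    # gathers the j-th y-chunk from every slab, making x fully local.
    pieces = [np.split(s, p, axis=1) for s in slabs]
    slabs = [np.concatenate([pieces[i][j] for i in range(p)], axis=0)
             for j in range(p)]
    # Step 3: local 1D FFT along the now-complete x axis.
    return [np.fft.fft(s, axis=0) for s in slabs]

rng = np.random.default_rng(2)
x = rng.standard_normal((8, 8, 8))
out = slab_fft3d_single_transpose(x, p=4)
print(np.allclose(np.concatenate(out, axis=1), np.fft.fftn(x)))  # True
```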
Finishing the paper
Adding support for non-contiguous FFTs and transposes
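This sketch assumes that "non-contiguous" means transforming along an axis that is not the fastest-varying one in memory. The alternative is to first reorder the data so that the transformed axis is contiguous. Both approaches give identical values and differ only in memory layout and performance. The small NumPy sketch below shows that equivalence. It illustrates the idea only and is not jaxDecomp's API.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.standard_normal((16, 8, 4)) + 1j * rng.standard_normal((16, 8, 4))

# Non-contiguous: transform along axis 0, the slowest-varying axis of a
# C-ordered array, directly in the current layout.
direct = np.fft.fft(x, axis=0)

# Contiguous: move axis 0 to the end and copy, so the transformed axis is
# contiguous in memory, then move it back after the FFT.
moved = np.ascontiguousarray(np.moveaxis(x, 0, -1))
contiguous = np.moveaxis(np.fft.fft(moved, axis=-1), -1, 0)

print(moved.flags["C_CONTIGUOUS"])      # True
print(np.allclose(direct, contiguous))  # True
```

The non-contiguous path skips the extra copy. The contiguous path pays for a transpose so that the FFT can read its axis with unit stride.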