jaxDecomp

JAX bindings for the NVIDIA cuDecomp library

10 jaxDecomp issues, sorted by most recently updated

This PR aims to fix the issues with 3D FFTs. Right now, it only changes the FFT test to check the result of the forward FFT, and this test...

First pull request for transpose ops. They work for cubes for now (pencils and slabs).

Fixed `slice_unpad`; it now supports complex numbers.

- [ ] `slice_unpad` does not seem to be too happy with complex numbers as inputs:
```
File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in
  recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL:...
```
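A single-device sketch of the property a pad/unpad round trip should satisfy for complex input: the unpadded array keeps its `complex64` dtype and matches the original exactly. `pad_local` and `unpad_local` are hypothetical helpers built on `jnp.pad` and plain slicing; they are not jaxDecomp's `slice_unpad`, which works on distributed arrays and also takes `pdims`.

```python
import jax.numpy as jnp

# Hypothetical helpers for illustration only; NOT jaxDecomp's slice_pad/slice_unpad.
def pad_local(x, pad_width):
    """Zero-pad each axis of `x` by pad_width[i] = (before, after)."""
    return jnp.pad(x, pad_width)

def unpad_local(x, pad_width):
    """Slice the padding back off; basic slicing preserves the input dtype."""
    slices = tuple(
        slice(before, dim - after)
        for (before, after), dim in zip(pad_width, x.shape)
    )
    return x[slices]

# Small complex test array
x = (jnp.arange(8.0) + 1j * jnp.arange(8.0)).reshape(2, 4).astype(jnp.complex64)
pad_width = ((1, 1), (2, 2))

padded = pad_local(x, pad_width)
restored = unpad_local(padded, pad_width)
print(padded.shape, restored.dtype)  # (4, 8) complex64
```

A test along these lines, run with both real and complex inputs, would catch a dtype-specific failure like the one in the traceback.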

Currently, the CUDA version in the CMake configuration is set to 12.2. jaxDecomp (and cuDecomp) can be compiled with CUDA 11.8 (there is no CUDA 12-specific code). JAX 0.4.26 and above no...

enhancement

Comparing the 3D FFT computed by jaxDecomp with one computed manually in JAX, I realized that the result of `fft3d` does not match the non-distributed version. This could be due to...

bug
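One way to set up such a comparison, sketched on a single device: a 3D FFT is separable, so three successive 1D FFTs must agree with `jnp.fft.fftn` up to float32 rounding. `fft3d_separable` is a hypothetical stand-in for the distributed `fft3d` output; in an actual test one would gather the jaxDecomp result and compare it against `jnp.fft.fftn` in the same way.

```python
import jax
import jax.numpy as jnp

def fft3d_separable(x):
    """3D FFT as three successive 1D FFTs, one axis at a time.

    A pencil-decomposed FFT performs the same 1D passes, with a transpose
    between passes so that the next axis is local to each device.
    """
    for axis in (2, 1, 0):
        x = jnp.fft.fft(x, axis=axis)
    return x

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 8, 8)).astype(jnp.complex64)

reference = jnp.fft.fftn(x)     # non-distributed reference
candidate = fft3d_separable(x)  # stand-in for the distributed result

# float32 FFTs accumulate rounding error, so compare with a tolerance
print(jnp.allclose(candidate, reference, rtol=1e-4, atol=1e-4))
```

Comparing with `allclose` rather than exact equality avoids flagging ordinary floating-point differences as a mismatch.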

Because implementing every feature is not necessarily worthwhile unless there is a need for it, here are the restrictions of the code in its current version. **All these restrictions...

A change in the jaxDecomp internals to allow a single transposition for slab decompositions, keeping the interface the same.
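To see why a slab decomposition needs only one transposition, here is a single-device simulation, assuming the array is split into slabs along axis 0: each slab does a local 2D FFT over axes 1 and 2, one transposition makes axis 0 local, and a final 1D FFT finishes the transform. `slab_fft3d` is hypothetical, and `jnp.split`/`jnp.concatenate` only stand in for the all-to-all exchange; this is not jaxDecomp's internal implementation.

```python
import jax
import jax.numpy as jnp

def slab_fft3d(x, n_slabs):
    """Simulate a slab-decomposed 3D FFT on one device (illustrative only)."""
    # Step 1: each "device" owns a slab along axis 0 and does a local 2D FFT
    slabs = jnp.split(x, n_slabs, axis=0)
    slabs = [jnp.fft.fft2(s, axes=(1, 2)) for s in slabs]
    y = jnp.concatenate(slabs, axis=0)

    # Step 2: the single transposition, i.e. redistribute along axis 1
    # so that axis 0 becomes fully local to each "device"
    slabs_t = jnp.split(y, n_slabs, axis=1)

    # Step 3: finish with a 1D FFT along the now-local axis 0
    slabs_t = [jnp.fft.fft(s, axis=0) for s in slabs_t]
    return jnp.concatenate(slabs_t, axis=1)

key = jax.random.PRNGKey(1)
x = jax.random.normal(key, (8, 8, 8)).astype(jnp.complex64)
out = slab_fft3d(x, n_slabs=4)
print(jnp.allclose(out, jnp.fft.fftn(x), rtol=1e-4, atol=1e-4))
```

A pencil decomposition splits two axes, so it needs a second transposition before the last 1D FFT; with slabs, two of the three axes are already local.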

Finishing the paper

Adding support for non-contiguous FFTs and transposes