jaxDecomp
Fix 3D FFT
This PR aims to fix the issues with 3D FFTs. For now, it only changes the FFT test to check the result of the forward FFT, and that test currently fails.
@aboucaud @EiffL, this is ready to go. FFTs and IFFTs work with pencils, XY slabs and YZ slabs, and all of them support both cubic and non-cubic arrays. Forward FFTs return Z-pencils when the decomposition is pencils or XY slabs, and Y-pencils when it is YZ slabs.
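The reason a distributed forward FFT ends in a different pencil orientation is that a 3D FFT is separable: it can be computed as three batched 1D FFTs, one axis at a time, with a transpose between passes so that the next axis becomes fully local. The sketch below is a single-process NumPy illustration of that separability, not jaxDecomp's implementation. `fft3d_by_axes` is a hypothetical helper, and the (X, Y, Z) = (axis 0, 1, 2) convention is an assumption.

```python
import numpy as np

def fft3d_by_axes(x):
    """Hypothetical helper: 3D FFT as three successive 1D FFTs.

    Assumes array axes are ordered (X, Y, Z). In a pencil decomposition,
    each pass would run on the axis that is fully local, and a transpose
    (all-to-all) between passes would make the next axis local.
    """
    y = np.fft.fft(x, axis=0)  # X-pencils: X is local
    y = np.fft.fft(y, axis=1)  # after an X->Y transpose, Y is local
    y = np.fft.fft(y, axis=2)  # after a Y->Z transpose, Z is local (Z-pencils)
    return y

# A non-cubic array, since non-cubes must work too.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6, 4))
assert np.allclose(fft3d_by_axes(x), np.fft.fftn(x))
```

The last pass leaves Z local, which matches the Z-pencil output described above for the pencil case.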
In the demo branch I added two scripts:
- transpose_demo: shows how the sharding changes from column-major to row-major, and how the next transpose changes it back from row-major to column-major. Pdims stays 2D at all times.
- sharding: shows how we can create a column-major processor grid in JAX.
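The difference between a row-major and a column-major processor grid is just the order in which device ranks fill the 2D grid. The NumPy sketch below shows both layouts for 8 devices and pdims (2, 4). `processor_grid` is a hypothetical helper, not taken from the sharding script. In JAX, an array of devices arranged the same way (for example `np.array(jax.devices()).reshape(pdims, order="F")`) could then be passed to `jax.sharding.Mesh`, though the demo script may build its grid differently.

```python
import numpy as np

def processor_grid(n_devices, pdims, order="C"):
    """Hypothetical helper: arrange ranks 0..n_devices-1 into a pdims-shaped grid.

    order="C" fills rows first (row-major);
    order="F" fills columns first (column-major).
    """
    px, py = pdims
    if px * py != n_devices:
        raise ValueError("pdims must multiply to the number of devices")
    return np.arange(n_devices).reshape((px, py), order=order)

print(processor_grid(8, (2, 4), order="C"))
# [[0 1 2 3]
#  [4 5 6 7]]
print(processor_grid(8, (2, 4), order="F"))
# [[0 2 4 6]
#  [1 3 5 7]]
```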
This branch purposely does not switch the sharding back, to show how it works (this transposition is fixed in the Transpose op branch).
In our case, we always start with column-major.
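One way to picture the col-major/row-major alternation, under the simplifying assumption that a transpose of the decomposition amounts to swapping the two processor-grid axes, is the NumPy sketch below. It is an illustration only, not how jaxDecomp or the Transpose op branch implements the transposition.

```python
import numpy as np

# Hypothetical 2x4 processor grid, filled column-major (ranks run down columns first).
pdims = (2, 4)
col_major = np.arange(8).reshape(pdims, order="F")

# Swapping the two grid axes gives a 4x2 grid whose ranks now run along
# rows, i.e. a row-major layout. Pdims is still 2D.
swapped = col_major.T
assert np.array_equal(swapped, np.arange(8).reshape((4, 2), order="C"))

# Swapping back restores the original column-major 2x4 grid.
assert np.array_equal(swapped.T, col_major)
```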