jaxDecomp

Fix 3D FFT

Open EiffL opened this issue 9 months ago • 2 comments

This PR aims to fix the issues with 3D FFTs. Right now, it just changes the FFT test to check the result of the forward FFT, and this test fails.

EiffL avatar Apr 28 '24 13:04 EiffL

@aboucaud @EiffL, this is ready to go. FFTs and IFFTs are compatible with pencils, XY slabs, and YZ slabs, and they all work with both cubic and non-cubic grids. Forward FFTs return Z-pencils if the decomposition is pencils or XY slabs, and Y-pencils for YZ slabs.
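To illustrate why a forward FFT can come back in a different axis layout than its input, here is a minimal single-device sketch. It is not jaxDecomp's implementation: the function name `fft3d_with_transposes` and the axis rotation order are assumptions made for illustration. It applies a 1D FFT along the last axis, rotates the axes, and repeats, then checks the result against `jnp.fft.fftn` on a non-cubic array.

```python
import numpy as np
import jax.numpy as jnp


def fft3d_with_transposes(x):
    """3D FFT built from 1D FFTs along the last axis plus axis rotations.

    Input axes are (X, Y, Z); the output axes are (Y, Z, X), so the
    result is laid out differently from the input.
    """
    x = jnp.fft.fft(x, axis=-1)        # FFT along Z, axes (X, Y, Z)
    x = jnp.transpose(x, (2, 0, 1))    # rotate to (Z, X, Y)
    x = jnp.fft.fft(x, axis=-1)        # FFT along Y
    x = jnp.transpose(x, (2, 0, 1))    # rotate to (Y, Z, X)
    x = jnp.fft.fft(x, axis=-1)        # FFT along X
    return x


# Non-cubic test array, shape (X, Y, Z) = (4, 6, 8).
rng = np.random.default_rng(0)
field = jnp.asarray(rng.standard_normal((4, 6, 8)), dtype=jnp.float32)

result = fft3d_with_transposes(field)

# Reference: full 3D FFT, reordered to the same (Y, Z, X) layout.
reference = jnp.transpose(jnp.fft.fftn(field), (1, 2, 0))
matches = bool(np.allclose(np.asarray(result), np.asarray(reference),
                           rtol=1e-4, atol=1e-3))
```

A test of the forward transform therefore has to compare against a reference in the same transposed layout, not against the untransposed `fftn` output.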

ASKabalan avatar May 04 '24 13:05 ASKabalan

In the demo branch I added two scripts:

  • transpose_demo: this demo shows how the sharding changes from column-major to row-major, and then how the next transpose changes it back from row-major to column-major. Pdims is 2D at all times.
  • sharding: this script shows how we can make a column-major processor grid in JAX.
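The difference between a row-major and a column-major processor grid can be sketched as follows. This is not the script from the demo branch: the helper `processor_grid` and the mesh axis names `"y"` and `"z"` are hypothetical, chosen for illustration. The only JAX calls used are `jax.devices()` and `jax.sharding.Mesh`.

```python
import numpy as np
import jax
from jax.sharding import Mesh


def processor_grid(devices, pdims, order="C"):
    """Arrange a flat list of devices into a 2D grid of shape pdims.

    order="C" fills the grid row-major; order="F" fills it column-major.
    """
    flat = np.empty(len(devices), dtype=object)
    for i, dev in enumerate(devices):
        flat[i] = dev
    return flat.reshape(pdims, order=order)


# Integers stand in for devices to make the two orderings visible.
row_major = processor_grid(list(range(6)), (2, 3), order="C")
col_major = processor_grid(list(range(6)), (2, 3), order="F")

# Build a mesh over whatever devices are available (possibly just one).
devices = jax.devices()
pdims = (1, len(devices))
mesh = Mesh(processor_grid(devices, pdims, order="F"), axis_names=("y", "z"))
```

With six placeholder devices, the row-major grid is `[[0, 1, 2], [3, 4, 5]]` and the column-major grid is `[[0, 2, 4], [1, 3, 5]]`.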

This branch purposely does not switch the sharding, to show how it works (this transposition is fixed in the Transpose op branch).

In our case, we always start with column-major.

ASKabalan avatar May 04 '24 13:05 ASKabalan