Francois Lanusse

Results: 58 comments by Francois Lanusse

@eelregit Making progress on this ^^ jaxdecomp (https://github.com/DifferentiableUniverseInitiative/jaxDecomp) can now do forward and backward FFTs. I still have to add a few things, but it's really not far away from...
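For readers unfamiliar with how a distributed 3D FFT is put together, here is a single-process sketch of one common scheme, a slab decomposition, written with `jax.numpy`. It is an illustration under stated assumptions, not jaxdecomp's implementation or API: the leading array axis stands in for the devices, and the reshape/transpose stands in for the all-to-all communication a real multi-GPU FFT performs (the library may well use a different layout, e.g. pencils). `slab_fft3d` and `gather_y_slabs` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def slab_fft3d(x, n_dev):
    """Forward 3D FFT via a slab decomposition, emulated on one process.

    Axis 0 of the intermediate arrays plays the role of the device index.
    Returns the result distributed as y-slabs: (n_dev, nx, ny // n_dev, nz).
    """
    nx, ny, nz = x.shape
    assert nx % n_dev == 0 and ny % n_dev == 0
    # 1. Each "device" owns a contiguous slab of x-planes: (n_dev, nx // n_dev, ny, nz).
    slabs = x.reshape(n_dev, nx // n_dev, ny, nz)
    # 2. The y and z axes are fully local, so transform them first.
    slabs = jnp.fft.fftn(slabs, axes=(2, 3))
    # 3. All-to-all transpose: split y into n_dev chunks and send chunk d to device d.
    blocks = slabs.reshape(n_dev, nx // n_dev, n_dev, ny // n_dev, nz)
    blocks = blocks.transpose(2, 0, 1, 3, 4)  # (dest, src, nx/n, ny/n, nz)
    ylabs = blocks.reshape(n_dev, nx, ny // n_dev, nz)
    # 4. The x axis is now fully local on every device: finish with a 1D FFT.
    return jnp.fft.fft(ylabs, axis=1)


def gather_y_slabs(ylabs):
    """Reassemble (n_dev, nx, ny_local, nz) y-slabs into one (nx, ny, nz) array."""
    n_dev, nx, ny_local, nz = ylabs.shape
    return ylabs.transpose(1, 0, 2, 3).reshape(nx, n_dev * ny_local, nz)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
result = gather_y_slabs(slab_fft3d(x, n_dev=4))
print(jnp.allclose(result, jnp.fft.fftn(x), atol=1e-3))
```

The backward transform would run the same steps in reverse order (1D inverse FFT along x, transpose back, 2D inverse FFT over y and z).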

OK, I've added halo exchange and cleaned up the interface, and I've also added gradients for these operations. You can also select which backend to use: MPI, NCCL, or NVSHMEM...
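To make "halo exchange" and "gradients of these operations" concrete, here is a single-process `jax.numpy` sketch, assuming periodic boundaries and a slab split along the first axis. The leading axis stands in for devices and `jnp.roll` stands in for the neighbour communication that a backend such as MPI, NCCL, or NVSHMEM would actually perform; `halo_exchange` is a hypothetical helper, not jaxdecomp's API.

```python
import jax
import jax.numpy as jnp


def halo_exchange(slabs, h):
    """Pad each slab along axis 1 with h planes from its periodic neighbours.

    slabs: (n_dev, s, ...), where the leading axis stands in for devices.
    Returns an array of shape (n_dev, s + 2 * h, ...).
    """
    # Device i receives the last h planes of device i - 1 (wrapping around) ...
    left = jnp.roll(slabs, 1, axis=0)[:, -h:]
    # ... and the first h planes of device i + 1.
    right = jnp.roll(slabs, -1, axis=0)[:, :h]
    return jnp.concatenate([left, slabs, right], axis=1)


n_dev, h = 4, 1
x = jnp.arange(16 * 2 * 2, dtype=jnp.float32).reshape(16, 2, 2)
s = 16 // n_dev
padded = halo_exchange(x.reshape(n_dev, s, 2, 2), h)

# Sanity check: slab i with its halos equals a window of the periodically padded array.
reference = jnp.pad(x, ((h, h), (0, 0), (0, 0)), mode="wrap")
print(all(jnp.array_equal(padded[i], reference[i * s : i * s + s + 2 * h]) for i in range(n_dev)))

# The exchange is built from differentiable jnp ops, so jax.grad works through it.
# Its adjoint sends halo contributions back to their owners: planes that also sit
# in a neighbour's halo receive gradient 2 instead of 1.
grad = jax.grad(lambda v: halo_exchange(v.reshape(n_dev, s, 2, 2), h).sum())(x)
print(grad[:, 0, 0])
```

The gradient check illustrates why a distributed halo exchange needs a matching reverse operation that accumulates halo contributions back onto the owning device.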

One 1024^3 FFT on 4 V100 GPUs... 0.5 ms :rofl:

Annnnd 50 ms for a 2048^3 FFT on 16 V100 GPUs across 2 nodes... :exploding_head: (also tagging @modichirag)

0.5 ms does sound really fast, but the FFT result seems to be correct, so... maybe? I'm not 100% sure what scaling we should expect as a function of...
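One quick way to judge whether a timing is plausible is a memory-bandwidth lower bound. The sketch below assumes complex64 data (8 bytes per cell), an even split over the 4 GPUs, and the V100's peak HBM2 bandwidth of about 900 GB/s; it ignores FLOPs and inter-GPU communication entirely, so the real cost can only be higher.

```python
# Back-of-envelope lower bound for one 1024^3 FFT on 4 V100s, from memory traffic alone.
n_cells = 1024**3
bytes_per_cell = 8        # assumption: complex64 (single-precision complex)
n_gpus = 4
hbm_bandwidth = 900e9     # V100 peak HBM2 bandwidth, bytes/s

bytes_per_gpu = n_cells * bytes_per_cell / n_gpus
# Even a single pass that reads and writes the local data once costs at least:
t_one_pass = 2 * bytes_per_gpu / hbm_bandwidth
print(f"{bytes_per_gpu / 2**30:.1f} GiB per GPU, >= {t_one_pass * 1e3:.2f} ms per read+write pass")
```

That gives 2 GiB per GPU and roughly 4.8 ms for a single read+write pass, and a full 3D FFT needs several such passes plus communication, so a 0.5 ms figure suggests the timer did not capture the whole computation.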

Oops ^^' you are right, I didn't include a `block_until_ready`... New timings on V100s:
- 2048^3 on 16 GPUs using NCCL: 8 s
- 1024^3 on 4 GPUs using...
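JAX dispatches work asynchronously: a call returns as soon as the computation is enqueued, so a Python timer around it mostly measures dispatch unless the result is explicitly waited on. A minimal single-device sketch of the pitfall, assuming a small 64^3 grid so it also runs on CPU; `time_call` is a hypothetical helper.

```python
import time

import jax
import jax.numpy as jnp

fft3d = jax.jit(jnp.fft.fftn)


def time_call(fn, x, n_iter=10, block=True):
    """Average wall-clock seconds per call of fn(x)."""
    fn(x).block_until_ready()  # warm-up: compile, and wait for it to finish
    t0 = time.perf_counter()
    for _ in range(n_iter):
        out = fn(x)
        if block:
            out.block_until_ready()  # wait until the device has actually finished
    return (time.perf_counter() - t0) / n_iter


x = jax.random.normal(jax.random.PRNGKey(0), (64, 64, 64)).astype(jnp.complex64)
t_async = time_call(fft3d, x, block=False)  # mostly measures dispatch
t_sync = time_call(fft3d, x, block=True)    # measures the computation itself
print(f"without block_until_ready: {t_async * 1e3:.3f} ms/call")
print(f"with block_until_ready:    {t_sync * 1e3:.3f} ms/call")
```

The warm-up call also keeps JIT compilation time out of the measurement, which matters just as much as blocking for a fair benchmark.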

Yep, it should be trivial to add an op for cuFFTMp in jaxdecomp: it's already part of the NVHPC SDK, so there's no need to compile an external library :-)...

@eelregit ... OK, it took about a year, but it's now working nicely with the latest version of JAX, thanks to the heroic efforts of my collaborator @ASKabalan :-) We have...

@eelregit Here is a minimal demo of LPT (Lagrangian perturbation theory) implemented with jaxdecomp: https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/lpt_nbody_demo.py
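As background on what an LPT step computes, here is a single-device `jax.numpy` sketch of first-order LPT (the Zel'dovich approximation), where the displacement field is obtained from the density contrast as psi(k) = i k delta(k) / k^2, so that div(psi) = -delta. This is not the code in the linked demo: it assumes a periodic cubic grid with an odd number of cells per side, and `wavenumbers`, `zeldovich_displacement`, `divergence`, and `box_size` are hypothetical names. In the distributed setting, each FFT below is where jaxdecomp's transforms would come in.

```python
import jax
import jax.numpy as jnp


def wavenumbers(n, box_size):
    """Angular wavenumbers k_x, k_y, k_z on an n^3 periodic grid."""
    k1d = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)
    return jnp.meshgrid(k1d, k1d, k1d, indexing="ij")


def zeldovich_displacement(delta, box_size=1.0):
    """First-order LPT displacement psi(k) = i k delta(k) / k^2, shape (3, n, n, n)."""
    kx, ky, kz = wavenumbers(delta.shape[0], box_size)
    k2 = kx**2 + ky**2 + kz**2
    k2 = k2.at[0, 0, 0].set(1.0)  # avoid 0/0; the k = 0 mode is zeroed just below
    delta_k = jnp.fft.fftn(delta).at[0, 0, 0].set(0.0)
    return jnp.stack([jnp.fft.ifftn(1j * k * delta_k / k2).real for k in (kx, ky, kz)])


def divergence(psi, box_size=1.0):
    """Spectral divergence of a vector field of shape (3, n, n, n)."""
    kx, ky, kz = wavenumbers(psi.shape[1], box_size)
    return sum(jnp.fft.ifftn(1j * k * jnp.fft.fftn(p)).real for k, p in zip((kx, ky, kz), psi))


# Odd grid size, so there is no Nyquist plane to special-case in this sketch.
delta = jax.random.normal(jax.random.PRNGKey(1), (15, 15, 15))
delta = delta - delta.mean()
psi = zeldovich_displacement(delta)
# First-order LPT satisfies div(psi) = -delta.
print(jnp.allclose(divergence(psi), -delta, atol=1e-4))
```

Particles would then be moved from their grid positions q to x = q + D(t) psi(q), with D(t) the linear growth factor.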

@eelregit Here are the timings: https://flanusse.net/talks/Split2024/#/15/0/1