GradDFT
Demonstrate parallel execution of a loss function
On an HPC cluster, each term in a mean-square loss can be computed independently, so the workload is embarrassingly parallel.
Unfortunately, the native JAX approach (jax.vmap and jax.pmap) is not compatible with the input we must parallelize over: the Molecule object. Its data has a "ragged" structure: the grid dimensions for one molecule very often differ from those of another, and likewise the dimensions of the 1-RDM differ from molecule to molecule, so jnp.array([rdm1_1, rdm1_2]) will not work.
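For instance, here is a minimal sketch of the failure, using hypothetical 1-RDM shapes (the spin axis and orbital counts below are made up for illustration):

```python
import jax.numpy as jnp

# Two molecules with different numbers of atomic orbitals, so their 1-RDMs
# have different shapes and cannot be stacked into one rectangular array.
rdm1_1 = jnp.zeros((2, 10, 10))  # (spin, AO, AO) for molecule 1
rdm1_2 = jnp.zeros((2, 14, 14))  # (spin, AO, AO) for molecule 2

# jax.vmap/jax.pmap need a leading batch axis over a rectangular array,
# but stacking ragged arrays fails:
try:
    batched = jnp.array([rdm1_1, rdm1_2])
except (TypeError, ValueError) as err:
    print(f"cannot batch ragged 1-RDMs: {err}")
```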
This means we need a different approach to loss parallelism. Sharding may be the way forward, but it requires more thought. A good reference is here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
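For what it's worth, a minimal sketch of that mechanism on rectangular (non-ragged) data, adapted from the linked notebook; the array sizes here are arbitrary, and making this work with ragged Molecule data is exactly the open problem:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over all available devices, with a named "batch" axis.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("batch",))

# A rectangular batch whose rows are sharded across devices.
n = jax.device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("batch", None)))

# jit-compiled computations on x now execute in parallel across the shards.
print(jax.jit(lambda a: jnp.sum(a**2))(x))
```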
I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.
@Matematija recommended sharding too.
Related to #83
Having played around with multi-host parallelism in JAX, I ran into many issues on Perlmutter with GPU detection.
I'm giving mpi4jax a go for this task now. It should be fairly straightforward, provided it works well on Perlmutter.
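As a starting point, a hedged sketch of what the mpi4jax version of the loss reduction could look like: each rank computes the squared-error terms for its own subset of molecules (dummy scalars below, standing in for real per-molecule computations, not the GradDFT API) and the partial sums are combined with an allreduce. Launched with e.g. `srun -n 4 python loss_mpi.py`:

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Dummy per-rank data standing in for per-molecule predictions/targets;
# each rank would really loop over its own (ragged) Molecule objects.
preds = jnp.array([1.0, 2.0]) + rank
targets = jnp.array([1.5, 1.5])

# Per-rank partial sum of squared errors.
partial = jnp.sum((preds - targets) ** 2)

# Global sum across all ranks; every rank receives the result.
total, _ = mpi4jax.allreduce(partial, op=MPI.SUM, comm=comm)

mse = total / (2 * comm.Get_size())
print(f"rank {rank}: global MSE = {mse}")
```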