
Demonstrate parallel execution of a loss function

Open • jackbaker1001 opened this issue 2 years ago • 3 comments

On an HPC cluster, each term in a mean squared error loss can be computed independently, so the evaluation is embarrassingly parallel.
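A minimal sketch of that structure, where `predict` is a hypothetical stand-in for a per-molecule model evaluation (not a GradDFT function):

```python
import jax.numpy as jnp

# Hypothetical stand-in for a per-molecule model evaluation; in practice this
# would be something like a GradDFT energy prediction for one Molecule.
def predict(params, molecule):
    return params * jnp.sum(molecule)

# The loss is a sum of independent per-molecule terms, so each term could in
# principle be evaluated on a separate worker and the partial sums combined.
def mse_loss(params, molecules, targets):
    errors = [predict(params, mol) - y for mol, y in zip(molecules, targets)]
    return sum(e ** 2 for e in errors) / len(errors)

# Note the ragged inputs: each "molecule" here is an array of a different size.
print(mse_loss(0.5, [jnp.ones(4), jnp.ones(7)], [2.0, 3.5]))
```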

Unfortunately, the native way of doing this in JAX (via jax.vmap and jax.pmap) is not compatible with the input we must parallelize over: the Molecule object. This is because its data is stored in a "ragged" structure: the grid dimensions for one molecule very often differ from those for another, and the dimensions of the 1-RDM likewise differ from molecule to molecule, so jnp.array([rdm1_1, rdm1_2]) will not work.
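A minimal illustration of the failure, with hypothetical 1-RDM shapes:

```python
import jax.numpy as jnp

# Two hypothetical 1-RDMs from molecules with different basis-set sizes.
rdm1_1 = jnp.zeros((10, 10))
rdm1_2 = jnp.zeros((14, 14))

# There is no rectangular array containing both, so there is nothing for
# jax.vmap or jax.pmap to map over along a batch axis:
jnp.array([rdm1_1, rdm1_2])  # raises an error: the element shapes disagree
```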

This means that for loss parallelism, we need to think differently. Sharding may be the way forward, but this requires more thought. A good reference is here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
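For reference, a minimal sketch of the sharding API from that notebook (details may differ across JAX versions). Note that it still assumes a regular, rectangular array, which is exactly what the Molecule data is not:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D device mesh and shard a regular array along its leading axis.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
x = jnp.arange(float(n * 8)).reshape(n, 8)
x = jax.device_put(x, NamedSharding(mesh, P("batch", None)))

# jit-compiled computations on x are partitioned across devices automatically.
y = jax.jit(lambda a: jnp.sin(a) ** 2)(x)
print(y.sharding)
```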

I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.

jackbaker1001 • Sep 12 '23 18:09

@Matematija recommended sharding too.

PabloAMC • Sep 12 '23 18:09

Related to #83

jackbaker1001 • Dec 07 '23 17:12

Having played around with multi-host parallelism in JAX, I came across many issues with GPU detection on Perlmutter.
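For context, a rough sketch of the basic multi-host setup; GPU-detection problems show up as a mismatch between the two device lists below, and the exact behaviour depends on the JAX version and cluster configuration:

```python
import jax

# On a Slurm machine like Perlmutter, jax.distributed.initialize() can
# usually infer the coordinator address, process count, and process id from
# the environment; they can also be passed explicitly if detection fails.
jax.distributed.initialize()

print("local devices: ", jax.local_devices())  # GPUs visible to this process
print("global devices:", jax.devices())        # GPUs across all processes
```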

I'm now giving mpi4jax a go for this task. This should be fairly straightforward if mpi4jax works well on Perlmutter.
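A sketch of what this could look like, roughly following the allreduce pattern from the mpi4jax documentation: each rank computes the loss terms for its own subset of molecules, so the ragged data never has to be stacked, and only scalar partial sums cross rank boundaries. The per-rank terms here are placeholders:

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Placeholder for the squared-error terms this rank is responsible for.
local_terms = jnp.ones(3) * rank

@jax.jit
def total_loss(terms):
    local_sum = jnp.sum(terms)
    # mpi4jax returns (result, token); the token sequences MPI ops under jit.
    global_sum, _ = mpi4jax.allreduce(local_sum, op=MPI.SUM, comm=comm)
    return global_sum

# Launch with e.g.: srun -n 4 python script.py
print(f"rank {rank}: total loss = {total_loss(local_terms)}")
```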

jackbaker1001 • Dec 11 '23 19:12