
Performance Bottleneck in `pm_evolve` When Running on CPU

Open panjiashu opened this issue 1 year ago • 0 comments

Thank you for the great work! 👍

I installed the environment following the instructions and am attempting to generate 2D compressible NS data.

With `numbers=1` in `2D_Multi_Rand.yaml`, I ran `sh run_trainset_2D.sh` with the following content:

```shell
nn=1
key=2031
while [ $nn -le 1 ]; do
    python3 CFD_multi_Hydra.py +args=2D_Multi_Rand.yaml ++args.init_key=$key
    nn=$(expr $nn + 1)
    key=$(expr $key + 1)
    echo "$nn"
    echo "$key"
done
```

It seems the CPU version of JAX is used by default, so I removed `CUDA_VISIBLE_DEVICES='0,1,2,3'`.

The program completes successfully: the `evolve` function takes 0.71 s, and the total runtime is 278 s. Increasing `nx`, `ny` from 128 to 256 makes the program run for ~80 min (~20x the time for the smaller grid). I have traced the latency to the return of the `pm_evolve` call: `t, DDD, VVx, VVy, VVz, PPP` can be computed in under 1 s, and since `pm_evolve` is a JAX-wrapped version of `evolve`, the problem must come from `pmap` or `vmap`.
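One thing that may be relevant here (general JAX behavior, not specific to PDEBench, so this is a guess at the cause): JAX dispatches computations asynchronously, so timing the line that produces the result arrays can measure only the dispatch, while the actual compute cost shows up later, at the first point where the results are materialized, e.g. on return/transfer. A minimal sketch of how to see this:

```python
import time

import jax.numpy as jnp

x = jnp.ones((4000, 4000))

start = time.perf_counter()
y = jnp.dot(x, x)  # dispatch returns quickly; the work runs asynchronously
dispatch_time = time.perf_counter() - start

start = time.perf_counter()
y.block_until_ready()  # forces the computation to actually finish
compute_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.4f}s, blocked: {compute_time:.4f}s")
```

If this is what's happening, wrapping the `pm_evolve` output in `jax.block_until_ready(...)` before timing would show where the time is really spent.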

I've confirmed that the backend is CPU. Do you have any advice on optimizing this, or on why the return step takes so long?
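For reference, this is how the backend and device count can be checked (standard JAX API, nothing PDEBench-specific); the device count is what `pmap` maps over, so on CPU it is typically 1 unless `XLA_FLAGS=--xla_force_host_platform_device_count=N` is set:

```python
import jax

print(jax.default_backend())      # e.g. "cpu"
print(jax.devices())              # list of available devices
print(jax.local_device_count())   # number of devices pmap maps over
```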

Thank you for your time and assistance!

panjiashu avatar Oct 20 '24 15:10 panjiashu