qmcpack
msd_fast is slow for non-orthogonal determinants
Describe the bug
When running the sum of 2 non-orthogonal Slater determinants using the msd_fast object (MultiSlaterDetTableMethod), I get a 15x slowdown compared to the single-determinant run, instead of the (2+epsilon)x slowdown I expected from a "fast" algorithm.
To Reproduce
Steps to reproduce the behavior:
- build commit 1e365df21c5a128879254dd8fb6798c149430966 from Fri Jul 14 2023
- cmake -D QMC_COMPLEX=1
- mpirun -np 64 qmcpack
Expected behavior
2-det VMC runs a bit over 2x slower than 1-det VMC.
System:
- Intel workstation (2x Intel(R) Xeon(R) Platinum 8362 CPU @ 2.80GHz)
- modules loaded: gcc/10.4.0 openmpi/4.0.7 intel-oneapi-mkl/2023.0.0 hdf5/mpi-1.10.9 boost/1.80.0
Additional context
In the attached archive tb36c_h54.zip,
the README file summarizes the issue.
The fine timer outputs indicate that buildTable and table2ratios take the majority of the time in the 2-det run.
These should be bypassed in a fast algorithm for non-orthogonal multi-determinant wavefunctions.
The relevant timer output for:
single Slater determinant
WaveFunction:psi0::VGL 71.2145 4.0097 13024713 0.000005468
  J1OrbitalSoA:lr::VGL 4.7487 4.7487 13024713 0.000000365
  J1OrbitalSoA:sr::VGL 2.8298 2.8298 13024713 0.000000217
  SlaterDet::VGL 54.1923 1.5460 13024713 0.000004161
    DiracDeterminant::ratio 1.8999 1.8999 13024713 0.000000146
    DiracDeterminant::spovgl 50.7464 50.7464 6512313 0.000007792
2-det msd_fast
WaveFunction:psi0::VGL 1280.3559 4.5288 13024698 0.000098302
  J1OrbitalSoA:lr::VGL 4.9064 4.9064 13024698 0.000000377
  J1OrbitalSoA:sr::VGL 2.8862 2.8862 13024698 0.000000222
  MultiSlaterDetTableMethod::VGL 1262.4257 1.2613 13024698 0.000096926
    MultiSlaterDetTableMethod::evalGrad 506.1446 31.5392 6512400 0.000077720
      MultiDiracDeterminant::calcGradRatios 474.6054 4.3001 19537200 0.000024292
        MultiDiracDeterminant::buildTable 314.5037 314.5037 19537200 0.000016098
        MultiDiracDeterminant::table2ratios 155.8015 155.8015 19537200 0.000007975
    MultiSlaterDetTableMethod::ratioGrad 755.0197 1.0360 6512298 0.000115938
      MultiDiracDeterminant::evaluateDetAndGrad 753.9838 4.0023 6512298 0.000115778
        MultiDiracDeterminant::calcGradRatios 472.7771 3.4091 19536894 0.000024199
          MultiDiracDeterminant::buildTable 314.8147 314.8147 19536894 0.000016114
          MultiDiracDeterminant::table2ratios 154.5534 154.5534 19536894 0.000007911
        MultiDiracDeterminant::calcRatios 157.6205 1.3158 6512298 0.000024204
          MultiDiracDeterminant::buildTable 104.6650 104.6650 6512298 0.000016072
          MultiDiracDeterminant::table2ratios 51.6398 51.6398 6512298 0.000007930
        MultiDiracDeterminant::evalOrbVGL 80.3637 80.3637 6512298 0.000012340
        MultiDiracDeterminant::updateRefDetInv 39.2201 39.2201 26049192 0.000001506
1-det msd_fast
WaveFunction:psi0::VGL 191.4760 4.3395 13024707 0.000014701
  J1OrbitalSoA:lr::VGL 4.7652 4.7652 13024707 0.000000366
  J1OrbitalSoA:sr::VGL 2.7720 2.7720 13024707 0.000000213
  MultiSlaterDetTableMethod::VGL 173.8985 1.2734 13024707 0.000013351
    MultiSlaterDetTableMethod::evalGrad 33.7062 29.7036 6512400 0.000005176
      MultiDiracDeterminant::calcGradRatios 4.0026 2.8915 19537200 0.000000205
        MultiDiracDeterminant::buildTable 0.4950 0.4950 19537200 0.000000025
        MultiDiracDeterminant::table2ratios 0.6161 0.6161 19537200 0.000000032
    MultiSlaterDetTableMethod::ratioGrad 138.9189 0.9109 6512307 0.000021332
      MultiDiracDeterminant::evaluateDetAndGrad 138.0080 3.3489 6512307 0.000021192
        MultiDiracDeterminant::calcGradRatios 3.6260 2.5791 19536921 0.000000186
          MultiDiracDeterminant::buildTable 0.4839 0.4839 19536921 0.000000025
          MultiDiracDeterminant::table2ratios 0.5630 0.5630 19536921 0.000000029
        MultiDiracDeterminant::calcRatios 1.2826 0.8845 6512307 0.000000197
          MultiDiracDeterminant::buildTable 0.1730 0.1730 6512307 0.000000027
          MultiDiracDeterminant::table2ratios 0.2252 0.2252 6512307 0.000000035
        MultiDiracDeterminant::evalOrbVGL 92.8755 92.8755 6512307 0.000014262
        MultiDiracDeterminant::updateRefDetInv 36.8749 36.8749 26049228 0.000001416
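As a quick check of the timer outputs above, the following sketch adds up the buildTable and table2ratios times in the 2-det run and compares the total VGL time across the three runs. It assumes the four numeric columns are inclusive time, exclusive time, call count, and time per call (which is consistent with how the nested entries sum), and the values are copied by hand rather than parsed from a file.

```python
# Values copied from the timer outputs above, in seconds.
# Assumption: for leaf timers, inclusive and exclusive time are the same number.
vgl_total = {
    "single_det": 71.2145,       # WaveFunction:psi0::VGL, single Slater determinant
    "msd_fast_1det": 191.4760,   # WaveFunction:psi0::VGL, 1-det msd_fast
    "msd_fast_2det": 1280.3559,  # WaveFunction:psi0::VGL, 2-det msd_fast
}

# Every buildTable / table2ratios entry in the 2-det msd_fast run.
build_table = [314.5037, 314.8147, 104.6650]
table2ratios = [155.8015, 154.5534, 51.6398]

table_time = sum(build_table) + sum(table2ratios)
table_share = table_time / vgl_total["msd_fast_2det"]
slowdown_vs_single = vgl_total["msd_fast_2det"] / vgl_total["single_det"]
slowdown_vs_1det_msd = vgl_total["msd_fast_2det"] / vgl_total["msd_fast_1det"]

print(f"buildTable + table2ratios: {table_time:.1f} s ({table_share:.0%} of VGL)")
print(f"2-det msd_fast vs single determinant (VGL only): {slowdown_vs_single:.1f}x")
print(f"2-det msd_fast vs 1-det msd_fast (VGL only): {slowdown_vs_1det_msd:.1f}x")
```

With these inputs, the two table routines account for roughly 86% of the 2-det VGL time. These ratios cover the VGL timer only, not the whole run.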
After a closer look, I found that this example exercises the worst case of the table method: it needs to excite 27 electrons.
Number of terms in pairs array: 0 # 1 det
Number of terms in pairs array: 729 # 2 dets, 27x27
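To illustrate why exciting every orbital is the worst case, here is a minimal sketch of the identity behind a table method, on toy 3x3 matrices with exact arithmetic. It is not QMCPACK code: the function names, the matrix orientation, and the convention that virtual column j replaces reference column j are assumptions made for the example. The ratio of an excited determinant to the reference equals the determinant of a k x k table, where k is the number of replaced orbitals, so the table has k*k entries.

```python
from fractions import Fraction

def det(m):
    """Determinant by Gaussian elimination (exact, using Fraction)."""
    a = [[Fraction(x) for x in row] for row in m]
    n, result = len(a), Fraction(1)
    for c in range(n):
        pivot = next((r for r in range(c, n) if a[r][c] != 0), None)
        if pivot is None:
            return Fraction(0)
        if pivot != c:
            a[c], a[pivot] = a[pivot], a[c]
            result = -result
        result *= a[c][c]
        for r in range(c + 1, n):
            f = a[r][c] / a[c][c]
            a[r] = [x - f * y for x, y in zip(a[r], a[c])]
    return result

def inverse(m):
    """Gauss-Jordan inverse (exact); assumes m is invertible."""
    n = len(m)
    a = [[Fraction(x) for x in row] + [Fraction(int(i == j)) for j in range(n)]
         for i, row in enumerate(m)]
    for c in range(n):
        pivot = next(r for r in range(c, n) if a[r][c] != 0)
        a[c], a[pivot] = a[pivot], a[c]
        p = a[c][c]
        a[c] = [x / p for x in a[c]]
        for r in range(n):
            if r != c:
                f = a[r][c]
                a[r] = [x - f * y for x, y in zip(a[r], a[c])]
    return [row[n:] for row in a]

def ratio_direct(ref, virt, holes):
    """det(excited) / det(ref), building the excited matrix explicitly."""
    n = len(ref)
    exc = [[virt[i][j] if j in holes else ref[i][j] for j in range(n)]
           for i in range(n)]
    return det(exc) / det(ref)

def ratio_via_table(ref, virt, holes):
    """Same ratio from a k x k table: one dot product per (p, q) pair."""
    n = len(ref)
    ref_inv = inverse(ref)
    table = [[sum(ref_inv[p][k] * virt[k][q] for k in range(n)) for q in holes]
             for p in holes]
    return det(table), len(holes) ** 2  # ratio, number of table entries

ref = [[2, 1, 0], [1, 3, 1], [0, 1, 4]]   # toy reference Slater matrix
virt = [[1, 2, 0], [0, 1, 1], [1, 0, 2]]  # toy replacement orbital columns

for holes in ([0], [0, 2], [0, 1, 2]):
    ratio, n_entries = ratio_via_table(ref, virt, holes)
    print(holes, ratio, ratio_direct(ref, virt, holes), n_entries)
```

By the same counting, replacing 27 orbitals gives a 27 x 27 table with 729 entries, which matches the pairs count reported above.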
buildTable fills the 27x27 matrices. The current algorithm assumes the matrix is sparse and uses a bunch of dot products. In your case, it can be done using GEMM and the algorithm remains generic.
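A rough sketch of the difference between the two strategies, using made-up integer data: filling only the listed pairs costs one dot product per pair, while filling the whole table is a single dense matrix product. The function names and array layout are hypothetical and do not reflect the actual buildTable implementation. Plain Python loops stand in for the dense product here, which compiled code would hand to one BLAS GEMM call.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def build_table_pairs(inv_rows, virt_cols, pairs):
    """Fill only the requested (i, j) entries, one dot product each."""
    return {(i, j): dot(inv_rows[i], virt_cols[j]) for i, j in pairs}

def build_table_dense(inv_rows, virt_cols):
    """Fill every entry at once; stand-in for a single GEMM call."""
    return [[dot(row, col) for col in virt_cols] for row in inv_rows]

n_exc = 27  # 27 excitations, as in this issue
n_el = 30   # arbitrary toy vector length, not taken from the issue

# Deterministic toy data (small integers so comparisons are exact).
inv_rows = [[(i + 2 * k) % 7 - 3 for k in range(n_el)] for i in range(n_exc)]
virt_cols = [[(3 * j + k) % 5 - 2 for k in range(n_el)] for j in range(n_exc)]

# When every (i, j) pair is needed, the "sparse" list is the full table.
pairs = [(i, j) for i in range(n_exc) for j in range(n_exc)]
sparse = build_table_pairs(inv_rows, virt_cols, pairs)
dense = build_table_dense(inv_rows, virt_cols)

print(len(pairs), all(sparse[(i, j)] == dense[i][j] for i, j in pairs))
```

Both routes produce the same table; the dense route only pays off when most pairs are needed, as in this 2-det case.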
table2ratios uses LAPACK to LU-factorize the 27x27 matrix and get the determinant value. I believe there is not much we can do.
Right, this is a hacky way to combine two non-orthogonal determinants. Can we add a flag to circumvent the table method? Since all the SPOs are evaluated, we just need to build two single determinants using the usual route, then add them.
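The proposal of building each determinant through the usual single-determinant route and then adding them can be written as a short formula. For Psi = sum_i c_i D_i, the ratio for a proposed move is the sum of c_i D_i r_i divided by the sum of c_i D_i, where r_i is the ratio each single determinant already computes. The function below is a hypothetical illustration, not an existing QMCPACK class or interface.

```python
def combined_ratio(coeffs, det_values, det_ratios):
    """Ratio Psi_new / Psi_old for Psi = sum_i c_i * D_i.

    det_values[i] is the current value of determinant i, and det_ratios[i]
    is its own single-determinant ratio D_i(new) / D_i(old).
    """
    old = sum(c * d for c, d in zip(coeffs, det_values))
    new = sum(c * d * r for c, d, r in zip(coeffs, det_values, det_ratios))
    return new / old

# Toy numbers: two determinants with coefficients 3 and 2.
example = combined_ratio([3.0, 2.0], [2.0, -1.0], [1.5, 0.5])
print(example)
```

The cost is then one single-determinant update per determinant, which is where the expected (2+epsilon)x figure comes from. A production version would likely track log values and phases instead of raw determinant values; the same expression also works with Python complex numbers.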
The time spent in buildTable can be reduced by adding an option, but the time in table2ratios cannot.
Circumventing the table method requires adding a class built on top of single dets.