abacus-develop

Duplicate functionality of sPsi functions

Open Cstandardlib opened this issue 1 year ago • 0 comments

Describe the Code Quality Issue

Issue Description

The logic of the sPsi function, which applies the overlap operator S to wavefunctions (S|psi>), is implemented in more than one place in the current codebase. This issue proposes removing that redundancy to make the code easier to maintain and read.

Observations

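Both spsi_func (in hsolver_pw.cpp) and HamiltPW::sPsi (in hamilt_pw.cpp) branch on whether ultrasoft pseudopotentials are in use (this->use_uspp and GlobalV::use_uspp, respectively). HamiltPW::sPsi already copies psi_in into spsi before adding the USPP term, so the copy in the else-branch of spsi_func repeats work that sPsi performs itself.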

auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
    if (this->use_uspp) {
        hm->sPsi(...);
    } else {
        base_device::memory::synchronize_memory_op<T, Device, Device>()(...);
    }
};

// Meanwhile, in the definition of HamiltPW::sPsi:
void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
                               T* spsi,         // spsi
                               const int nrow,  // dimension of spsi: nbands * nrow
                               const int npw,   // number of plane waves
                               const int nbands // number of bands
) const{
    if(GlobalV::use_paw)
    {
#ifdef USE_PAW
        for(int m = 0; m < nbands; m ++)
        {
            GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*> (&psi_in[m * npw]),
                reinterpret_cast<std::complex<double>*>(&spsi[m * nrow]));
        }
#endif
        return;
    }

    syncmem_op()(this->ctx, this->ctx, spsi, psi_in, static_cast<size_t>(nbands * nrow));
    if (GlobalV::use_uspp)
    { ... }
}
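A minimal sketch of one way to consolidate the dispatch so that S|psi> is computed in exactly one place. It assumes a hypothetical `OverlapOp` type and `OverlapKind` enum (neither exists in ABACUS). The USPP and PAW kernels are injected as placeholder callables standing in for the real nonlocal-projector code and `GlobalC::paw_cell.paw_nl_psi`. For simplicity, the sketch also assumes that psi_in and spsi share the leading dimension nrow.

```cpp
#include <algorithm>
#include <cassert>
#include <complex>
#include <cstddef>
#include <functional>
#include <vector>

using cplx = std::complex<double>;

// Hypothetical flag replacing the separate use_uspp / use_paw checks.
enum class OverlapKind { NormConserving, USPP, PAW };

// Hypothetical single owner of the S|psi> logic. The kernels are placeholders:
// in ABACUS they would wrap the USPP projectors and GlobalC::paw_cell.paw_nl_psi.
struct OverlapOp {
    OverlapKind kind = OverlapKind::NormConserving;
    std::function<void(const cplx* psi, cplx* spsi, int npw)> paw_band;      // writes S|psi> for one band
    std::function<void(const cplx* psi, cplx* spsi, int npw)> uspp_add_band; // adds (S - 1)|psi> for one band

    // Applies S to nbands bands, each stored with leading dimension nrow.
    void apply(const cplx* psi_in, cplx* spsi, int nrow, int npw, int nbands) const {
        assert(npw <= nrow);
        for (int m = 0; m < nbands; ++m) {
            const cplx* in = psi_in + static_cast<std::size_t>(m) * nrow;
            cplx* out = spsi + static_cast<std::size_t>(m) * nrow;
            if (kind == OverlapKind::PAW) {
                assert(paw_band);
                paw_band(in, out, npw);
                continue;
            }
            // Norm-conserving case: S is the identity, so S|psi> = |psi>.
            std::copy(in, in + nrow, out);
            if (kind == OverlapKind::USPP) {
                assert(uspp_add_band);
                uspp_add_band(in, out, npw); // USPP: add the correction on top of the copy
            }
        }
    }
};
```

With this shape, a caller such as spsi_func would call `apply` unconditionally, and the Davidson solver would no longer need its own use_paw branch. Keeping the identity copy inside `apply` lets the norm-conserving and USPP paths share a single code path.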

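The PAW branch is also nearly identical in the Davidson solver and in the sPsi implementation. Both check a use_paw flag, guard the call with the USE_PAW preprocessor macro, and call GlobalC::paw_cell.paw_nl_psi band by band. This is a second instance of duplicated logic.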

// sPsi is called here (Davidson solver)
for (int m = 0; m < notconv; m++)
{
    if (this->use_paw)
    {
#ifdef USE_PAW
        GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*>(basis + dim*(nbase + m)),
            reinterpret_cast<std::complex<double>*>(&spsi[(nbase + m) * dim]));
#endif
    }
    else
    {
        spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
    }
}
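A minimal sketch of how the Davidson loop above might look if sPsi owned all dispatch. It assumes, as the HamiltPW::sPsi snippet in this issue suggests, that sPsi already handles the PAW case, so the caller needs no use_paw branch. `SPsiFn`, `spsi_for_new_vectors`, and `spsi_per_vector` are hypothetical names. Because new basis vector m starts at `basis + dim * (nbase + m)`, the notconv vectors are contiguous, and a single call with nbands = notconv should be equivalent to the per-vector loop.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <functional>
#include <vector>

using cplx = std::complex<double>;

// Hypothetical callable mirroring sPsi(psi_in, spsi, nrow, npw, nbands);
// PAW / USPP / identity dispatch is assumed to live inside it.
using SPsiFn = std::function<void(const cplx* psi_in, cplx* spsi, int nrow, int npw, int nbands)>;

// Refactored step: one batched call over the notconv contiguous new vectors.
void spsi_for_new_vectors(const SPsiFn& spsi_func, const cplx* basis, cplx* spsi,
                          int dim, int nbase, int notconv) {
    assert(dim > 0 && nbase >= 0 && notconv >= 0);
    const std::size_t offset = static_cast<std::size_t>(nbase) * dim;
    spsi_func(basis + offset, spsi + offset, dim, dim, notconv);
}

// Per-vector pattern from the loop above, with the PAW branch dropped
// because spsi_func is assumed to handle PAW internally.
void spsi_per_vector(const SPsiFn& spsi_func, const cplx* basis, cplx* spsi,
                     int dim, int nbase, int notconv) {
    for (int m = 0; m < notconv; ++m) {
        const std::size_t offset = static_cast<std::size_t>(nbase + m) * dim;
        spsi_func(basis + offset, spsi + offset, dim, dim, 1);
    }
}
```

Whether the batched form is preferable in practice depends on how the real sPsi kernels scale with nbands. The sketch only shows that the two call patterns address the same memory.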

Additional Context

No response

Task list for Issue attackers (only for developers)

  • [ ] Identify the specific code file or section with the code quality issue.
  • [ ] Investigate the issue and determine the root cause.
  • [ ] Research best practices and potential solutions for the identified issue.
  • [ ] Refactor the code to improve code quality, following the suggested solution.
  • [ ] Ensure the refactored code adheres to the project's coding standards.
  • [ ] Test the refactored code to ensure it functions as expected.
  • [ ] Update any relevant documentation, if necessary.
  • [ ] Submit a pull request with the refactored code and a description of the changes made.

Cstandardlib, Aug 07 '24 12:08