Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions source/source_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ add_library(
${objects}
)

# Executable target built from test_serial/pw_simd_bench.cpp
# (presumably a copy/SIMD micro-benchmark, judging by the file name).
add_executable(
    MODULE_PW_simd_bench
    test_serial/pw_simd_bench.cpp
)

# Use the keyword (PRIVATE) signature: an executable has no link interface to
# propagate, and this matches the target_link_libraries(planewave PRIVATE ...)
# call in this file. Every target_link_libraries() call on this target must
# use the same signature, because CMake rejects mixing plain and keyword forms.
target_link_libraries(
    MODULE_PW_simd_bench
    PRIVATE
    parameter
    ${math_libs}
    planewave
    device
    base
    Threads::Threads
)

# Link the OpenMP C++ runtime only when the build enables USE_OPENMP.
if(USE_OPENMP)
    target_link_libraries(MODULE_PW_simd_bench PRIVATE OpenMP::OpenMP_CXX)
endif()

if (USE_DSP)
target_link_libraries(planewave PRIVATE
${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a)
Expand Down
109 changes: 73 additions & 36 deletions source/source_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,49 @@
#include "pw_basis.h"
#include "source_base/global_function.h"
#include "source_base/timer.h"
#include <algorithm>
#include <typeinfo>

namespace ModulePW
{
namespace detail
{
/**
 * @brief Copy a contiguous run of complex values on the calling thread.
 * @param in: source buffer, at least `count` elements
 * @param out: destination buffer, at least `count` elements
 * @param count: number of elements to copy; values <= 0 make this a no-op
 */
template <typename T>
inline void copy_complex_buffer(const std::complex<T>* in, std::complex<T>* out, const int count)
{
    if (count > 0)
    {
        std::copy(in, in + count, out);
    }
}

/**
 * @brief Copy a contiguous run of complex values, splitting the work across
 *        OpenMP threads when the buffer is larger than one block.
 * @param in: source buffer, at least `count` elements
 * @param out: destination buffer, at least `count` elements
 * @param count: number of elements to copy
 *
 * This variant opens its own OpenMP parallel region, so it is meant for
 * top-level transform copies. Loops that are already running inside a
 * parallel region should call copy_complex_buffer() instead.
 * Without OpenMP, or for buffers of at most 1024 elements, it simply forwards
 * to the serial helper.
 */
template <typename T>
inline void copy_complex_buffer_parallel(const std::complex<T>* in, std::complex<T>* out, const int count)
{
#ifdef _OPENMP
    // Elements handled per loop iteration; also the threshold below which
    // opening a parallel region is skipped.
    constexpr int block_len = 1024;
    if (count > block_len)
    {
#pragma omp parallel for schedule(static)
        for (int begin = 0; begin < count; begin += block_len)
        {
            // The last block may be shorter than block_len.
            const int remaining = count - begin;
            const int len = remaining < block_len ? remaining : block_len;
            std::copy_n(in + begin, len, out + begin);
        }
        return;
    }
#endif
    copy_complex_buffer(in, out, count);
}
} // namespace detail

/**
* @brief gather planes and scatter sticks
* @param in: (nplane,fftny,fftnx)
Expand All @@ -21,19 +60,18 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int nst_ = this->nst;
const int nz_ = this->nz;
const int* istot2ixy_ = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gatherp_copy_serial");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[is*nz_];
std::complex<T> *inp = &in[ixy*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[is*nz_];
const std::complex<T>* inp = &in[ixy*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
ModuleBase::timer::end(this->classname, "gatherp_copy_serial");
return;
}

Expand All @@ -44,19 +82,18 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int nstot_gps = this->nstot;
const int nplane_gps = this->nplane;
const int* istot2ixy_gps = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gatherp_copy_pack");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int istot = 0; istot < nstot_gps; ++istot)
{
int ixy = istot2ixy_gps[istot];
std::complex<T> *outp = &out[istot * nplane_gps];
std::complex<T> *inp = &in[ixy * nplane_gps];
for (int iz = 0; iz < nplane_gps; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[istot * nplane_gps];
const std::complex<T>* inp = &in[ixy * nplane_gps];
detail::copy_complex_buffer(inp, outp, nplane_gps);
}
ModuleBase::timer::end(this->classname, "gatherp_copy_pack");

//exchange data
//(nplane,nstot) to (numz[ip],ns, poolnproc)
Expand All @@ -80,6 +117,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int* numz_gps = this->numz;
const int* startg_gps = this->startg;
const int* startz_gps = this->startz;
ModuleBase::timer::start(this->classname, "gatherp_copy_unpack");
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
Expand All @@ -90,14 +128,12 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_gps[ip];
std::complex<T> *outp0 = &out[startz_gps[ip]];
std::complex<T> *inp0 = &in[startg_gps[ip]];
std::complex<T> *outp = &outp0[is * nz_gps];
std::complex<T> *inp = &inp0[is * nzip ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nz_gps];
const std::complex<T>* inp = &inp0[is * nzip ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
ModuleBase::timer::end(this->classname, "gatherp_copy_unpack");
#endif
return;
}
Expand All @@ -118,27 +154,28 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
const int nst_ = this->nst;
const int nz_ = this->nz;
const int* istot2ixy_ = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gathers_zero_serial");
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for(int i = 0; i < nrxx_; ++i)
{
out[i] = std::complex<T>(0, 0);
}
ModuleBase::timer::end(this->classname, "gathers_zero_serial");

ModuleBase::timer::start(this->classname, "gathers_copy_serial");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[ixy*nz_];
std::complex<T> *inp = &in[is*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy*nz_];
const std::complex<T>* inp = &in[is*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
ModuleBase::timer::end(this->classname, "gathers_copy_serial");
return;
}

Expand All @@ -152,6 +189,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
const int* numz_ = this->numz;
const int* startg_ = this->startg;
const int* startz_ = this->startz;
ModuleBase::timer::start(this->classname, "gathers_copy_pack");
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
Expand All @@ -162,14 +200,12 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_[ip];
std::complex<T> *outp0 = &out[startg_[ip]];
std::complex<T> *inp0 = &in[startz_[ip]];
std::complex<T> *outp = &outp0[is * nzip];
std::complex<T> *inp = &inp0[is * nz_ ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nzip];
const std::complex<T>* inp = &inp0[is * nz_ ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
ModuleBase::timer::end(this->classname, "gathers_copy_pack");

//exchange data
//(numz[ip],ns, poolnproc) to (nplane,nstot)
Expand All @@ -187,31 +223,32 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
}

const int nrxx_gsp = this->nrxx;
ModuleBase::timer::start(this->classname, "gathers_zero_mpi");
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for(int i = 0; i < nrxx_gsp; ++i)
{
out[i] = std::complex<T>(0, 0);
}
ModuleBase::timer::end(this->classname, "gathers_zero_mpi");
//change (nplane,nstot) to (nplane fftnxy)
const int nstot = this->nstot;
const int nplane = this->nplane;
const int* istot2ixy = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gathers_copy_unpack");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int istot = 0;istot < nstot; ++istot)
{
int ixy = istot2ixy[istot];
//int ixy = (ixy / fftny)*ny + ixy % fftny;
std::complex<T> *outp = &out[ixy * nplane];
std::complex<T> *inp = &in[istot * nplane];
for (int iz = 0; iz < nplane; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy * nplane];
const std::complex<T>* inp = &in[istot * nplane];
detail::copy_complex_buffer(inp, outp, nplane);
}
ModuleBase::timer::end(this->classname, "gathers_copy_unpack");
#endif
return;
}
Expand Down
18 changes: 3 additions & 15 deletions source/source_basis/module_pw/pw_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
const int npw_ = this->npw;
const int nxyz_ = this->nxyz;
const int* ig2isz_ = this->ig2isz;
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, this->fft_bundle.get_auxr_data<FPTYPE>(), nrxx_);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -199,13 +193,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
}
detail::copy_complex_buffer_parallel(this->fft_bundle.get_auxr_data<FPTYPE>(), out, nrxx_);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down Expand Up @@ -340,4 +328,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
std::complex<double>* out,
const bool add,
const double factor) const;
} // namespace ModulePW
} // namespace ModulePW
16 changes: 2 additions & 14 deletions source/source_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex<FPTYPE>* in,

assert(this->gamma_only == false);
auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
auxr[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, auxr, this->nrxx);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
out[ir] = auxr[ir];
}
detail::copy_complex_buffer_parallel(auxr, out, this->nrxx);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down
Loading
Loading