Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 7 additions & 24 deletions source/source_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "source_base/global_function.h"
#include "source_base/timer.h"
#include <typeinfo>
#include <cstring>

namespace ModulePW
{
Expand Down Expand Up @@ -29,10 +30,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[is*nz_];
std::complex<T> *inp = &in[ixy*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::memcpy(outp, inp, nz_ * sizeof(std::complex<T>));
}
return;
}
Expand All @@ -52,10 +50,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int ixy = istot2ixy_gps[istot];
std::complex<T> *outp = &out[istot * nplane_gps];
std::complex<T> *inp = &in[ixy * nplane_gps];
for (int iz = 0; iz < nplane_gps; ++iz)
{
outp[iz] = inp[iz];
}
std::memcpy(outp, inp, nplane_gps * sizeof(std::complex<T>));
}

//exchange data
Expand Down Expand Up @@ -92,10 +87,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
std::complex<T> *inp0 = &in[startg_gps[ip]];
std::complex<T> *outp = &outp0[is * nz_gps];
std::complex<T> *inp = &inp0[is * nzip ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::memcpy(outp, inp, nzip * sizeof(std::complex<T>));
}
}
#endif
Expand Down Expand Up @@ -134,10 +126,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[ixy*nz_];
std::complex<T> *inp = &in[is*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::memcpy(outp, inp, nz_ * sizeof(std::complex<T>));
}
return;
}
Expand All @@ -164,10 +153,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
std::complex<T> *inp0 = &in[startz_[ip]];
std::complex<T> *outp = &outp0[is * nzip];
std::complex<T> *inp = &inp0[is * nz_ ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::memcpy(outp, inp, nzip * sizeof(std::complex<T>));
}
}

Expand Down Expand Up @@ -207,10 +193,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
//int ixy = (ixy / fftny)*ny + ixy % fftny;
std::complex<T> *outp = &out[ixy * nplane];
std::complex<T> *inp = &in[istot * nplane];
for (int iz = 0; iz < nplane; ++iz)
{
outp[iz] = inp[iz];
}
std::memcpy(outp, inp, nplane * sizeof(std::complex<T>));
}
#endif
return;
Expand Down
91 changes: 71 additions & 20 deletions source/source_basis/module_pw/pw_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,30 @@ namespace ModulePW
// return default_device_cpu;
// }
/**
* @brief transform real space to reciprocal space
* @details c(g)=\int dr*f(r)*exp(-ig*r)
* Here we calculate c(g)
* @param in: (nplane,ny,nx), std::complex<double> data
* @param out: (nz, ns), std::complex<double> data
* @brief Transform real-space data to reciprocal-space plane-wave coefficients (complex input).
* @details
* Performs the forward 3D FFT with MPI parallel transposition for non-gamma-only calculations.
* Computes: c(g) = (1/N) * sum_r f(r) * exp(-i g·r)
*
* This is the #1 hotspot function in ABACUS, accounting for 22-30% of total SCF runtime.
* The implementation uses a 2D domain decomposition strategy:
* 1. Copy input real-space data to FFT buffer (z-slab distributed, O(nrxx))
* 2. In-place 2D FFT on each xy-plane (fftxyfor, independent per process)
* 3. MPI_Alltoallv transposition: xy-planes → z-sticks (gatherp_scatters)
* - Communication volume: O(nst_per * nz * sizeof(complex))
* 4. In-place 1D FFT along each z-stick (fftzfor)
* 5. Extract plane-wave coefficients: out[ig] = auxg[ig2isz[ig]] / nxyz
*
* @tparam FPTYPE Floating-point precision (float or double)
* @param in Input real-space array, shape (nplane, ny, nx) in z-slab distribution
* Each MPI process holds nplane xy-planes, for nrxx = nplane*nx*ny local elements
* @param out Output reciprocal-space array, shape (npw) — plane-wave coefficients
* Only stores coefficients for G-vectors on this process (stick distribution)
* @param add If true, add scaled result to existing out[]; if false, overwrite out[]
* @param factor Scaling factor applied when add=true: out[ig] += factor * c(g)
* @note The 1/nxyz normalization is always applied regardless of add/factor
* @note For gamma-only calculations, use the real-input overload (r2c FFT path)
* @see recip2real() for the inverse transform, gatherp_scatters() for MPI communication
*/
template <typename FPTYPE>
void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
Expand Down Expand Up @@ -73,11 +92,20 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
}

/**
* @brief transform real space to reciprocal space
* @details c(g)=\int dr*f(r)*exp(-ig*r)
* Here we calculate c(g)
* @param in: (nplane,ny,nx), double data
* @param out: (nz, ns), std::complex<double> data
* @brief Transform real-valued real-space data to reciprocal-space (gamma-only or non-gamma).
* @details
* Two code paths depending on gamma_only flag:
* - gamma_only=true: Uses r2c FFT (fftxyr2c). Only half the FFT grid is stored (fftnx = nx/2+1),
* exploiting Hermitian symmetry to save ~50% memory and computation.
* - gamma_only=false: Converts real input to complex, then follows the same 3D FFT path as
* the complex-input overload (fftxyfor → gatherp_scatters → fftzfor).
*
* @tparam FPTYPE Floating-point precision (float or double)
* @param in Input real-space array (real-valued), shape (nplane, ny, nx) in z-slab distribution
* @param out Output reciprocal-space plane-wave coefficients (complex)
* @param add If true, accumulate scaled result into out[]; if false, overwrite
* @param factor Scaling factor for add mode
* @see real2recip(const std::complex<FPTYPE>*, ...) for the complex-input variant
*/
template <typename FPTYPE>
void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const bool add, const FPTYPE factor) const
Expand Down Expand Up @@ -147,11 +175,25 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const boo
}

/**
* @brief transform reciprocal space to real space
* @details f(r)=1/V * \sum_{g} c(g)*exp(ig*r)
* Here we calculate f(r)
* @param in: (nz,ns), std::complex<double>
* @param out: (nplane, ny, nx), std::complex<double>
* @brief Transform reciprocal-space plane-wave coefficients to real-space data (complex output).
* @details
* Performs the inverse 3D FFT — the reverse of real2recip():
* f(r) = sum_g c(g) * exp(i g·r)
*
* Algorithm (reverse of real2recip):
* 1. Zero-fill the FFT stick buffer (nst*nz elements), then scatter: auxg[ig2isz[ig]] = in[ig]
* 2. Backward 1D FFT along each z-stick (fftzbac)
* 3. MPI_Alltoallv transposition: sticks → xy-planes (gathers_scatterp, reverse direction)
* 4. Backward 2D FFT on each xy-plane (fftxybac)
* 5. Copy/extract real-space result: out[ir] = auxr[ir]
*
* @tparam FPTYPE Floating-point precision (float or double)
* @param in Input reciprocal-space array, shape (npw) — plane-wave coefficients in stick distribution
* @param out Output real-space array, shape (nplane, ny, nx) in z-slab distribution
* @param add If true, add scaled result to existing out[]; if false, overwrite
* @param factor Scaling factor for add mode: out[ir] += factor * f(r)
* @note No 1/nxyz normalization factor is applied (unlike real2recip)
* @see real2recip() for the forward transform, gathers_scatterp() for MPI communication
*/
template <typename FPTYPE>
void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
Expand Down Expand Up @@ -211,11 +253,20 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
}

/**
* @brief transform reciprocal space to real space
* @details f(r)=1/V * \sum_{g} c(g)*exp(ig*r)
* Here we calculate f(r)
* @param in: (nz,ns), std::complex<double>
* @param out: (nplane, ny, nx), double
* @brief Transform reciprocal-space to real-valued real-space (gamma-only or non-gamma).
* @details
* Two code paths:
* - gamma_only=true: Uses c2r FFT (fftxyc2r) to exploit Hermitian symmetry. After backward 1D FFT
* and MPI transposition, applies c2r FFT producing real-valued output directly.
* - gamma_only=false: Follows the standard complex path (fftzbac → gathers_scatterp → fftxybac),
* then extracts the real part of the complex result.
*
* @tparam FPTYPE Floating-point precision (float or double)
* @param in Input reciprocal-space plane-wave coefficients (complex)
* @param out Output real-space array (real-valued)
* @param add If true, accumulate scaled result into out[]; if false, overwrite
* @param factor Scaling factor for add mode
* @see recip2real(const std::complex<FPTYPE>*, std::complex<FPTYPE>*, ...) for complex output
*/
template <typename FPTYPE>
void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const bool add, const FPTYPE factor) const
Expand Down
40 changes: 34 additions & 6 deletions source/source_hamilt/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,42 @@ class Operator
Operator();
virtual ~Operator();

// this is the core function for Operator
// do H|psi> from input |psi> ,
/// as default, different operators donate hPsi independently
/// run this->act function for the first operator and run all act() for other nodes in chain table
/// if this procedure is not suitable for your operator, just override this function.
/// output of hpsi would be first member of the returned tuple
/// @brief input type of hPsi: (psi_input, range, hpsi_pointer)
/// @details
/// hpsi_info bundles the input and output of hPsi():
/// - std::get<0>: pointer to psi::Psi<T, Device> (input wavefunction)
/// - std::get<1>: psi::Range specifying which bands to operate on
/// - std::get<2>: T* pointer to output hpsi buffer (can equal psi_in for in-place)
typedef std::tuple<const psi::Psi<T, Device>*, const psi::Range, T*> hpsi_info;

/// @brief Core hot-path: compute H|psi> by traversing the operator chain.
/// @details
/// This is the central computational kernel of ABACUS, called O(n_bands * n_iter * n_kpoints)
/// times during a single SCF calculation. It accounts for 18-25% of total runtime.
///
/// Algorithm:
/// 1. Unwrap the input hpsi_info to get psi, band range, and output buffer
/// 2. Allocate or reuse the temporary hpsi buffer via get_hpsi()
/// 3. Call act() on the first operator node (is_first_node=true) -- this node zeros hpsi
/// 4. Iterate through next_op linked list, calling act() on each subsequent node (is_first_node=false)
/// Each node accumulates its contribution: hpsi += O|psi>
/// 5. If in_place mode, copy temporary hpsi back to the caller-provided buffer
/// 6. Return wrapped hpsi_info for downstream use
///
/// The operator chain typically includes (in order):
/// Ekinetic → Veff → Nonlocal → Meta → OnsiteProj
///
/// @param input hpsi_info tuple: (psi_input, band_range, hpsi_output_pointer)
/// - psi_input: the wavefunction Psi object
/// - band_range: which bands to compute (range_1 to range_2 inclusive)
/// - hpsi_output_pointer: pre-allocated buffer for H|psi> result.
/// If equal to psi_input pointer, in_place mode is used (temporary buffer allocated internally).
/// @return hpsi_info containing (internal_hpsi, range, caller_hpsi_pointer)
/// @note This function is performance-critical. The operator chain traversal is the innermost
/// loop of iterative diagonalization methods (CG, Davidson, BPCG).
/// @note For PW calculations, each act() call may involve FFT transforms (Veff),
/// BLAS3 gemm operations (Nonlocal), or element-wise vector ops (Ekinetic).
/// @see Operator::act(), HamiltPW, DiagoCG::diag()
virtual hpsi_info hPsi(hpsi_info& input) const;

virtual void init(const int ik_in);
Expand Down
Loading