LCOV - code coverage report
Current view: top level - linalg - vector.hpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 47.7 % 44 21
Test Date: 2025-10-23 22:45:05 Functions: 40.0 % 30 12
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_LINALG_VECTOR_HPP
       5              : #define PALACE_LINALG_VECTOR_HPP
       6              : 
       7              : #include <complex>
       8              : #include <vector>
       9              : #include <mfem.hpp>
      10              : #include "utils/communication.hpp"
      11              : 
      12              : namespace palace
      13              : {
      14              : 
      15              : //
      16              : // Functionality extending mfem::Vector from MFEM, including basic functions for parallel
      17              : // vectors distributed across MPI processes.
      18              : //
      19              : 
      20              : using Vector = mfem::Vector;
      21              : 
      22              : // A complex-valued vector represented as two real vectors, one for each component.
      23        61941 : class ComplexVector
      24              : {
      25              : private:
      26              :   Vector xr, xi;
      27              : 
      28              : public:
      29              :   // Create a vector with the given size.
      30              :   ComplexVector(int size = 0);
      31              : 
      32              :   // Copy constructor.
      33              :   ComplexVector(const ComplexVector &y);
      34              : 
      35              :   // Copy constructor from separately provided real and imaginary parts.
      36              :   ComplexVector(const Vector &yr, const Vector &yi);
      37              : 
      38              :   // Copy constructor from an array of complex values.
      39              :   ComplexVector(const std::complex<double> *py, int size, bool on_dev);
      40              : 
      41              :   // Create a vector referencing the memory of another vector, at the given base offset and
      42              :   // size.
      43              :   ComplexVector(Vector &y, int offset, int size);
      44              : 
      45              :   // Flag for runtime execution on the mfem::Device. See the documentation for mfem::Vector.
      46              :   void UseDevice(bool use_dev);
      47              :   bool UseDevice() const { return xr.UseDevice(); }
      48              : 
      49              :   // Return the size of the vector.
      50              :   int Size() const { return xr.Size(); }
      51              : 
      52              :   // Set the size of the vector. See the notes for Vector::SetSize for behavior in the cases
      53              :   // where the new size is less than or greater than Size() or Capacity().
      54              :   void SetSize(int size);
      55              : 
      56              :   // Set this vector to reference the memory of another vector, at the given base offset and
      57              :   // size.
      58              :   void MakeRef(Vector &y, int offset, int size);
      59              : 
      60              :   // Get access to the real and imaginary vector parts.
      61           33 :   const Vector &Real() const { return xr; }
      62           66 :   Vector &Real() { return xr; }
      63           33 :   const Vector &Imag() const { return xi; }
      64           60 :   Vector &Imag() { return xi; }
      65              : 
      66              :   // Set from a ComplexVector, without resizing.
      67              :   void Set(const ComplexVector &y);
      68              :   ComplexVector &operator=(const ComplexVector &y)
      69              :   {
      70            0 :     Set(y);
      71            0 :     return *this;
      72              :   }
      73              : 
      74              :   // Set from separately provided real and imaginary parts, without resizing.
      75              :   void Set(const Vector &yr, const Vector &yi);
      76              : 
      77              :   // Set from an array of complex values, without resizing.
      78              :   void Set(const std::complex<double> *py, int size, bool on_dev);
      79              : 
      80              :   // Copy the vector into an array of complex values.
      81              :   void Get(std::complex<double> *py, int size, bool on_dev) const;
      82              : 
      83              :   // Set all entries equal to s.
      84              :   ComplexVector &operator=(std::complex<double> s);
      85              :   ComplexVector &operator=(double s)
      86              :   {
      87            9 :     *this = std::complex<double>(s, 0.0);
      88            0 :     return *this;
      89              :   }
      90              : 
      91              :   // Set the vector from an array of blocks and coefficients, without resizing.
      92              :   void SetBlocks(const std::vector<const ComplexVector *> &y,
      93              :                  const std::vector<std::complex<double>> &s);
      94              : 
      95              :   // Scale all entries by s.
      96              :   ComplexVector &operator*=(std::complex<double> s);
      97              : 
      98              :   // Replace entries with their complex conjugate.
      99              :   void Conj();
     100              : 
     101              :   // Replace entries with their absolute value.
     102              :   void Abs();
     103              : 
     104              :   // Set all entries to their reciprocal.
     105              :   void Reciprocal();
     106              : 
     107              :   // Vector dot product (yᴴ x) or indefinite dot product (yᵀ x) for complex vectors.
     108              :   std::complex<double> Dot(const ComplexVector &y) const;
     109              :   std::complex<double> TransposeDot(const ComplexVector &y) const;
     110            0 :   std::complex<double> operator*(const ComplexVector &y) const { return Dot(y); }
     111              : 
     112              :   // In-place addition (*this) += alpha * x.
     113              :   void AXPY(std::complex<double> alpha, const ComplexVector &x);
     114            0 :   void Add(std::complex<double> alpha, const ComplexVector &x) { AXPY(alpha, x); }
     115              :   void Subtract(std::complex<double> alpha, const ComplexVector &x) { AXPY(-alpha, x); }
     116              :   ComplexVector &operator+=(const ComplexVector &x)
     117              :   {
     118            0 :     AXPY(1.0, x);
     119            0 :     return *this;
     120              :   }
     121              :   ComplexVector &operator-=(const ComplexVector &x)
     122              :   {
     123            3 :     AXPY(-1.0, x);
     124              :     return *this;
     125              :   }
     126              : 
     127              :   // In-place addition (*this) = alpha * x + beta * (*this).
     128              :   void AXPBY(std::complex<double> alpha, const ComplexVector &x, std::complex<double> beta);
     129              : 
     130              :   // In-place addition (*this) = alpha * x + beta * y + gamma * (*this).
     131              :   void AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
     132              :                 std::complex<double> beta, const ComplexVector &y,
     133              :                 std::complex<double> gamma);
     134              : 
     135              :   static void AXPY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
     136              :                    Vector &yr, Vector &yi);
     137              : 
     138              :   static void AXPBY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
     139              :                     std::complex<double> beta, Vector &yr, Vector &yi);
     140              : 
     141              :   static void AXPBYPCZ(std::complex<double> alpha, const Vector &xr, const Vector &xi,
     142              :                        std::complex<double> beta, const Vector &yr, const Vector &yi,
     143              :                        std::complex<double> gamma, Vector &zr, Vector &zi);
     144              : };
     145              : 
     146              : // A stack-allocated vector with compile-time fixed size.
     147              : //
     148              : // StaticVector provides a Vector interface backed by stack memory instead of
     149              : // heap allocation. The size N is fixed at compile time, making it suitable for
     150              : // small vectors where performance and avoiding dynamic allocation are
     151              : // important.
     152              : //
     153              : // Template parameters:
     154              : // - N: The fixed size of the vector (number of elements)
     155              : //
     156              : // Notes:
     157              : // - Inherits from mfem::Vector, so can be used anywhere Vector is expected.
     158              : // - Memory is automatically managed (no new/delete needed).
     159              : // - Faster than dynamic Vector for small sizes due to stack allocation.
     160              : //
     161              : // Example usage:
     162              : //
     163              : // StaticVector<3> vec;  // 3D vector on stack
     164              : // vec[0] = 1.0;
     165              : // vec[1] = 2.0;
     166              : // vec[2] = 3.0;
     167              : //
     168              : // vec.Sum();
     169              : //
     170              : // You can also create StaticComplexVectors:
     171              : //
     172              : // StaticVector<3> vec_real, vec_imag;
     173              : // ComplexVector complex_vec(vec_real, vec_imag);
     174              : template <int N>
     175              : class StaticVector : public Vector
     176              : {
     177              : private:
     178              :   double buff[N];
     179              : 
     180              : public:
     181       115214 :   StaticVector() : Vector() { SetDataAndSize(buff, N); }
     182              : 
     183            0 :   ~StaticVector()
     184              :   {
     185              :     MFEM_ASSERT(GetData() == buff,
     186              :                 "Buffer of StaticVector changed. This indicates a possible bug.");
     187              :     MFEM_ASSERT(Size() == N, "Size of StaticVector changed. This indicates a possible bug.")
     188         9609 :   }
     189              : 
     190              :   using Vector::operator=;  // Extend the implicitly defined assignment operators
     191              : };
     192              : 
     193              : namespace linalg
     194              : {
     195              : 
     196              : // Returns the global vector size.
     197              : template <typename VecType>
     198              : inline HYPRE_BigInt GlobalSize(MPI_Comm comm, const VecType &x)
     199              : {
     200              :   HYPRE_BigInt N = x.Size();
     201              :   Mpi::GlobalSum(1, &N, comm);
     202              :   return N;
     203              : }
     204              : 
     205              : // Returns the global vector size for two vectors.
     206              : template <typename VecType1, typename VecType2>
     207            0 : inline std::pair<HYPRE_BigInt, HYPRE_BigInt> GlobalSize2(MPI_Comm comm, const VecType1 &x1,
     208              :                                                          const VecType2 &x2)
     209              : {
     210            0 :   HYPRE_BigInt N[2] = {x1.Size(), x2.Size()};
     211              :   Mpi::GlobalSum(2, N, comm);
     212            0 :   return {N[0], N[1]};
     213              : }
     214              : 
// Sets the entries of x at the given indices (rows) either to the single real value s or
// to the values of y. NOTE(review): whether y is indexed by the same rows or packed with
// rows.Size() entries is decided by the implementation — confirm.
template <typename VecType>
void SetSubVector(VecType &x, const mfem::Array<int> &rows, double s);
template <typename VecType>
void SetSubVector(VecType &x, const mfem::Array<int> &rows, const VecType &y);

// Copies y into the contiguous entries of x beginning at index start (presumably
// x[start + i] = y[i] for each entry i of y).
template <typename VecType>
void SetSubVector(VecType &x, int start, const VecType &y);

// Sets all entries of x in the half-open index range [start, end) to the value s.
template <typename VecType>
void SetSubVector(VecType &x, int start, int end, double s);
     229              : 
// Fill x with random values generated from the given seed.
// SetRandom: entries are sampled from [-1, 1], or for complex-valued vectors from
// [-1 - 1i, 1 + 1i] (presumably real and imaginary parts each in [-1, 1]).
// SetRandomReal / SetRandomSign: NOTE(review): not described here; judging only by the
// names, presumably real-valued random entries and random signs respectively — confirm
// against the implementation.
template <typename VecType>
void SetRandom(MPI_Comm comm, VecType &x, int seed = 0);
template <typename VecType>
void SetRandomReal(MPI_Comm comm, VecType &x, int seed = 0);
template <typename VecType>
void SetRandomSign(MPI_Comm comm, VecType &x, int seed = 0);
     238              : 
// Calculate the inner product of the locally stored entries of x and y on this process
// only (no MPI_Comm is taken, so no parallel reduction happens here): yᵀ x for real
// vectors, and presumably yᴴ x for ComplexVector as in ComplexVector::Dot — confirm.
double LocalDot(const Vector &x, const Vector &y);
std::complex<double> LocalDot(const ComplexVector &x, const ComplexVector &y);
     242              : 
     243              : // Calculate the parallel inner product yᴴ x or yᵀ x.
     244              : template <typename VecType>
     245            0 : inline auto Dot(MPI_Comm comm, const VecType &x, const VecType &y)
     246              : {
     247            0 :   auto dot = LocalDot(x, y);
     248              :   Mpi::GlobalSum(1, &dot, comm);
     249            0 :   return dot;
     250              : }
     251              : 
     252              : // Calculate the vector 2-norm.
     253              : template <typename VecType>
     254            0 : inline auto Norml2(MPI_Comm comm, const VecType &x)
     255              : {
     256            0 :   return std::sqrt(std::abs(Dot(comm, x, x)));
     257              : }
     258              : 
     259              : // Normalize the vector, possibly with respect to an SPD matrix B.
     260              : template <typename VecType>
     261              : inline auto Normalize(MPI_Comm comm, VecType &x)
     262              : {
     263              :   auto norm = Norml2(comm, x);
     264              :   MFEM_ASSERT(norm > 0.0, "Zero vector norm in normalization!");
     265              :   x *= 1.0 / norm;
     266              :   return norm;
     267              : }
     268              : 
// Calculate the sum of the locally stored entries of x on this process only (no MPI_Comm
// is taken, so no parallel reduction happens here).
double LocalSum(const Vector &x);
std::complex<double> LocalSum(const ComplexVector &x);
     272              : 
     273              : // Calculate the sum of all elements in the vector.
     274              : template <typename VecType>
     275            8 : inline auto Sum(MPI_Comm comm, const VecType &x)
     276              : {
     277            8 :   auto sum = LocalSum(x);
     278              :   Mpi::GlobalSum(1, &sum, comm);
     279            8 :   return sum;
     280              : }
     281              : 
     282              : // Calculate the mean of all elements in the vector.
     283              : template <typename VecType>
     284            0 : inline auto Mean(MPI_Comm comm, const VecType &x)
     285              : {
     286              :   using ScalarType = typename std::conditional<std::is_same<VecType, ComplexVector>::value,
     287              :                                                std::complex<double>, double>::type;
     288            0 :   ScalarType sum[2] = {LocalSum(x), ScalarType(x.Size())};
     289              :   Mpi::GlobalSum(2, sum, comm);
     290            0 :   return sum[0] / sum[1];
     291              : }
     292              : 
     293              : // Normalize a complex vector so its mean is on the positive real axis.
     294              : // Returns the original mean phase.
     295            0 : inline double NormalizePhase(MPI_Comm comm, ComplexVector &x)
     296              : {
     297            0 :   std::complex<double> mean = Mean(comm, x);
     298            0 :   x *= std::conj(mean) / std::abs(mean);
     299            0 :   return std::atan2(mean.imag(), mean.real());
     300              : }
     301              : 
// Free-function vector updates with the output vector passed last (these take no
// MPI_Comm).

// Addition y += alpha * x.
template <typename VecType, typename ScalarType>
void AXPY(ScalarType alpha, const VecType &x, VecType &y);

// Addition y = alpha * x + beta * y.
template <typename VecType, typename ScalarType>
void AXPBY(ScalarType alpha, const VecType &x, ScalarType beta, VecType &y);

// Addition z = alpha * x + beta * y + gamma * z.
template <typename VecType, typename ScalarType>
void AXPBYPCZ(ScalarType alpha, const VecType &x, ScalarType beta, const VecType &y,
              ScalarType gamma, VecType &z);
     314              : 
// Overwrite x with its element-wise square root, optionally scaled: each entry becomes
// sqrt(s * entry), with s multiplied in before the square root. With the default
// s = 1.0 this is a plain element-wise square root.
void Sqrt(Vector &x, double s = 1.0);
     318              : 
     319              : // Compute the 3D Cartesian product between A and B and store the result in C.
     320              : // If add is true, accumulate the result to C instead of overwriting its
     321              : // content.
     322              : template <typename VecTypeA, typename VecTypeB, typename VecTypeC>
     323       142953 : void Cross3(const VecTypeA &A, const VecTypeB &B, VecTypeC &C, bool add = false)
     324              : {
     325       142953 :   if (add)
     326              :   {
     327         7777 :     C[0] += A[1] * B[2] - A[2] * B[1];
     328         7777 :     C[1] += A[2] * B[0] - A[0] * B[2];
     329         7777 :     C[2] += A[0] * B[1] - A[1] * B[0];
     330              :   }
     331              :   else
     332              :   {
     333       135176 :     C[0] = A[1] * B[2] - A[2] * B[1];
     334       135176 :     C[1] = A[2] * B[0] - A[0] * B[2];
     335       135176 :     C[2] = A[0] * B[1] - A[1] * B[0];
     336              :   }
     337       142953 : }
     338              : 
     339              : }  // namespace linalg
     340              : 
     341              : }  // namespace palace
     342              : 
     343              : #endif  // PALACE_LINALG_VECTOR_HPP
        

Generated by: LCOV version 2.0-1