LCOV - code coverage report
Current view: top level - linalg - operator.hpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 10.0 % 70 7
Test Date: 2025-10-23 22:45:05 Functions: 6.7 % 30 2
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_LINALG_OPERATOR_HPP
       5              : #define PALACE_LINALG_OPERATOR_HPP
       6              : 
       7              : #include <complex>
       8              : #include <memory>
       9              : #include <type_traits>
      10              : #include <utility>
      11              : #include <vector>
      12              : #include "linalg/vector.hpp"
      13              : 
      14              : namespace palace
      15              : {
      16              : 
      17              : //
      18              : // Functionality extending mfem::Operator from MFEM.
      19              : //
      20              : 
      21              : using Operator = mfem::Operator;
      22              : 
      23              : // Abstract base class for complex-valued operators.
      24              : class ComplexOperator
      25              : {
      26              : protected:
      27              :   // The size of the complex-valued operator.
      28              :   int height, width;
      29              : 
      30              : public:
      31            9 :   ComplexOperator(int s = 0) : height(s), width(s) {}
      32            9 :   ComplexOperator(int h, int w) : height(h), width(w) {}
      33              :   virtual ~ComplexOperator() = default;
      34              : 
      35              :   // Get the height (size of output) of the operator.
      36            6 :   int Height() const { return height; }
      37              : 
      38              :   // Get the width (size of input) of the operator.
      39            3 :   int Width() const { return width; }
      40              : 
      41              :   // Test whether or not the operator is purely real or imaginary.
      42            0 :   virtual bool IsReal() const { return !Imag(); }
      43            0 :   virtual bool IsImag() const { return !Real(); }
      44              : 
      45              :   // Get access to the real and imaginary operator parts separately (may be empty if
      46              :   // operator is purely real or imaginary).
      47              :   virtual const Operator *Real() const;
      48              :   virtual const Operator *Imag() const;
      49              : 
      50              :   // Diagonal assembly.
      51              :   virtual void AssembleDiagonal(ComplexVector &diag) const;
      52              : 
      53              :   // Operator application.
      54              :   virtual void Mult(const ComplexVector &x, ComplexVector &y) const = 0;
      55              : 
      56              :   virtual void MultTranspose(const ComplexVector &x, ComplexVector &y) const;
      57              : 
      58              :   virtual void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const;
      59              : 
      60              :   virtual void AddMult(const ComplexVector &x, ComplexVector &y,
      61              :                        const std::complex<double> a = 1.0) const;
      62              : 
      63              :   virtual void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
      64              :                                 const std::complex<double> a = 1.0) const;
      65              : 
      66              :   virtual void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
      67              :                                          const std::complex<double> a = 1.0) const;
      68              : };
      69              : 
      70              : // A complex-valued operator represented using a block 2 x 2 equivalent-real formulation:
      71              : //                          [ yr ]  =  [ Ar  -Ai ] [ xr ]
      72              : //                          [ yi ]     [ Ai   Ar ] [ xi ] .
      73              : class ComplexWrapperOperator : public ComplexOperator
      74              : {
      75              : private:
      76              :   // Storage and access for real and imaginary parts of the operator.
      77              :   std::unique_ptr<Operator> data_Ar, data_Ai;
      78              :   const Operator *Ar, *Ai;
      79              : 
      80              :   // Temporary storage for operator application.
      81              :   mutable ComplexVector tx, ty;
      82              : 
      83              :   ComplexWrapperOperator(std::unique_ptr<Operator> &&dAr, std::unique_ptr<Operator> &&dAi,
      84              :                          const Operator *pAr, const Operator *pAi);
      85              : 
      86              : public:
      87              :   // Construct a complex operator which inherits ownership of the input real and imaginary
      88              :   // parts.
      89              :   ComplexWrapperOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai);
      90              : 
      91              :   // Non-owning constructor.
      92              :   ComplexWrapperOperator(const Operator *Ar, const Operator *Ai);
      93              : 
      94           42 :   const Operator *Real() const override { return Ar; }
      95           42 :   const Operator *Imag() const override { return Ai; }
      96              : 
      97              :   void AssembleDiagonal(ComplexVector &diag) const override;
      98              : 
      99              :   void Mult(const ComplexVector &x, ComplexVector &y) const override;
     100              : 
     101              :   void MultTranspose(const ComplexVector &x, ComplexVector &y) const override;
     102              : 
     103              :   void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
     104              : 
     105              :   void AddMult(const ComplexVector &x, ComplexVector &y,
     106              :                const std::complex<double> a = 1.0) const override;
     107              : 
     108              :   void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
     109              :                         const std::complex<double> a = 1.0) const override;
     110              : 
     111              :   void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
     112              :                                  const std::complex<double> a = 1.0) const override;
     113              : };
     114              : 
     115              : // Wrap a sequence of operators of the same dimensions and optional coefficients.
     116              : class SumOperator : public Operator
     117              : {
     118              : private:
     119              :   std::vector<std::pair<const Operator *, double>> ops;
     120              :   mutable Vector z;
     121              : 
     122              : public:
     123              :   SumOperator(int s) : Operator(s) { z.UseDevice(true); }
     124            9 :   SumOperator(int h, int w) : Operator(h, w) { z.UseDevice(true); }
     125              :   SumOperator(const Operator &op, double a = 1.0);
     126              : 
     127              :   void AddOperator(const Operator &op, double a = 1.0);
     128              : 
     129              :   void Mult(const Vector &x, Vector &y) const override;
     130              : 
     131              :   void MultTranspose(const Vector &x, Vector &y) const override;
     132              : 
     133              :   void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
     134              : 
     135              :   void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
     136              : };
     137              : 
     138              : // Wraps two operators such that: (AB)ᵀ = BᵀAᵀ and, for complex symmetric operators, the
     139              : // Hermitian transpose operation is (AB)ᴴ = BᴴAᴴ.
     140              : template <typename ProductOperator, typename OperType>
     141              : class ProductOperatorHelper : public OperType
     142              : {
     143              : };
     144              : 
     145              : template <typename ProductOperator>
     146              : class ProductOperatorHelper<ProductOperator, Operator> : public Operator
     147              : {
     148              : public:
     149              :   ProductOperatorHelper(int h, int w) : Operator(h, w) {}
     150              : };
     151              : 
     152              : template <typename ProductOperator>
     153              : class ProductOperatorHelper<ProductOperator, ComplexOperator> : public ComplexOperator
     154              : {
     155              : public:
     156              :   ProductOperatorHelper(int h, int w) : ComplexOperator(h, w) {}
     157              : 
     158            0 :   void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override
     159              :   {
     160            0 :     const ComplexOperator &A = static_cast<const ProductOperator *>(this)->A;
     161            0 :     const ComplexOperator &B = static_cast<const ProductOperator *>(this)->B;
     162            0 :     ComplexVector &z = static_cast<const ProductOperator *>(this)->z;
     163            0 :     A.MultHermitianTranspose(x, z);
     164            0 :     B.MultHermitianTranspose(z, y);
     165            0 :   }
     166              : 
     167            0 :   void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
     168              :                                  const std::complex<double> a = 1.0) const override
     169              :   {
     170            0 :     const ComplexOperator &A = static_cast<const ProductOperator *>(this)->A;
     171            0 :     const ComplexOperator &B = static_cast<const ProductOperator *>(this)->B;
     172            0 :     ComplexVector &z = static_cast<const ProductOperator *>(this)->z;
     173            0 :     A.MultHermitianTranspose(x, z);
     174            0 :     B.AddMultHermitianTranspose(z, y, a);
     175            0 :   }
     176              : };
     177              : 
     178              : template <typename OperType>
     179            0 : class BaseProductOperator
     180              :   : public ProductOperatorHelper<BaseProductOperator<OperType>, OperType>
     181              : {
     182              :   friend class ProductOperatorHelper<BaseProductOperator<OperType>, OperType>;
     183              : 
     184              :   using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     185              :                                             ComplexVector, Vector>::type;
     186              :   using ScalarType =
     187              :       typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     188              :                                 std::complex<double>, double>::type;
     189              : 
     190              : private:
     191              :   const OperType &A, &B;
     192              :   mutable VecType z;
     193              : 
     194              : public:
     195            0 :   BaseProductOperator(const OperType &A, const OperType &B)
     196              :     : ProductOperatorHelper<BaseProductOperator<OperType>, OperType>(A.Height(), B.Width()),
     197            0 :       A(A), B(B), z(B.Height())
     198              :   {
     199            0 :     z.UseDevice(true);
     200            0 :   }
     201              : 
     202            0 :   void Mult(const VecType &x, VecType &y) const override
     203              :   {
     204            0 :     B.Mult(x, z);
     205            0 :     A.Mult(z, y);
     206            0 :   }
     207              : 
     208            0 :   void MultTranspose(const VecType &x, VecType &y) const override
     209              :   {
     210            0 :     A.MultTranspose(x, z);
     211            0 :     B.MultTranspose(z, y);
     212            0 :   }
     213              : 
     214            0 :   void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
     215              :   {
     216            0 :     B.Mult(x, z);
     217            0 :     A.AddMult(z, y, a);
     218            0 :   }
     219              : 
     220            0 :   void AddMultTranspose(const VecType &x, VecType &y,
     221              :                         const ScalarType a = 1.0) const override
     222              :   {
     223            0 :     A.MultTranspose(x, z);
     224            0 :     B.AddMultTranspose(z, y, a);
     225            0 :   }
     226              : };
     227              : 
     228              : using ProductOperator = BaseProductOperator<Operator>;
     229              : using ComplexProductOperator = BaseProductOperator<ComplexOperator>;
     230              : 
     231              : // Applies the simple, symmetric but not necessarily Hermitian, operator: diag(d).
     232              : template <typename DiagonalOperator, typename OperType>
     233              : class DiagonalOperatorHelper : public OperType
     234              : {
     235              : };
     236              : 
     237              : template <typename DiagonalOperator>
     238              : class DiagonalOperatorHelper<DiagonalOperator, Operator> : public Operator
     239              : {
     240              : public:
     241              :   DiagonalOperatorHelper(int s) : Operator(s) {}
     242              : };
     243              : 
     244              : template <typename DiagonalOperator>
     245              : class DiagonalOperatorHelper<DiagonalOperator, ComplexOperator> : public ComplexOperator
     246              : {
     247              : public:
     248              :   DiagonalOperatorHelper(int s) : ComplexOperator(s) {}
     249              : 
     250              :   void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
     251              : 
     252              :   void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
     253              :                                  const std::complex<double> a = 1.0) const override;
     254              : };
     255              : 
     256              : template <typename OperType>
     257            0 : class BaseDiagonalOperator
     258              :   : public DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>
     259              : {
     260              :   friend class DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>;
     261              : 
     262              :   using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     263              :                                             ComplexVector, Vector>::type;
     264              :   using ScalarType =
     265              :       typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     266              :                                 std::complex<double>, double>::type;
     267              : 
     268              : private:
     269              :   const VecType &d;
     270              : 
     271              : public:
     272            0 :   BaseDiagonalOperator(const VecType &d)
     273            0 :     : DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>(d.Size()), d(d)
     274              :   {
     275              :   }
     276              : 
     277              :   void Mult(const VecType &x, VecType &y) const override;
     278              : 
     279            0 :   void MultTranspose(const VecType &x, VecType &y) const override { Mult(x, y); }
     280              : 
     281              :   void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override;
     282              : 
     283            0 :   void AddMultTranspose(const VecType &x, VecType &y,
     284              :                         const ScalarType a = 1.0) const override
     285              :   {
     286            0 :     AddMult(x, y, a);
     287            0 :   }
     288              : };
     289              : 
     290              : using DiagonalOperator = BaseDiagonalOperator<Operator>;
     291              : using ComplexDiagonalOperator = BaseDiagonalOperator<ComplexOperator>;
     292              : 
     293              : // A container for a sequence of operators corresponding to a multigrid hierarchy.
     294              : // Optionally includes operators for the auxiliary space at each level as well. The
     295              : // Operators are stored from coarsest to finest level. The height and width of this operator
     296              : // are never set.
     297              : template <typename OperType>
     298              : class BaseMultigridOperator : public OperType
     299              : {
     300              :   using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     301              :                                             ComplexVector, Vector>::type;
     302              :   using ScalarType =
     303              :       typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
     304              :                                 std::complex<double>, double>::type;
     305              : 
     306              : private:
     307              :   std::vector<std::unique_ptr<OperType>> ops, aux_ops;
     308              : 
     309              : public:
     310            0 :   BaseMultigridOperator(std::size_t l) : OperType(0)
     311              :   {
     312            0 :     ops.reserve(l);
     313            0 :     aux_ops.reserve(l);
     314            0 :   }
     315              : 
     316              :   void AddOperator(std::unique_ptr<OperType> &&op)
     317              :   {
     318            0 :     ops.push_back(std::move(op));
     319            0 :     this->height = ops.back()->Height();
     320            0 :     this->width = ops.back()->Width();
     321              :   }
     322              : 
     323              :   void AddAuxiliaryOperator(std::unique_ptr<OperType> &&aux_op)
     324              :   {
     325            0 :     aux_ops.push_back(std::move(aux_op));
     326            0 :   }
     327              : 
     328              :   bool HasAuxiliaryOperators() const { return !aux_ops.empty(); }
     329              :   auto GetNumLevels() const { return ops.size(); }
     330              :   auto GetNumAuxiliaryLevels() const { return aux_ops.size(); }
     331              : 
     332              :   const OperType &GetFinestOperator() const { return *ops.back(); }
     333              :   const OperType &GetFinestAuxiliaryOperator() const { return *aux_ops.back(); }
     334              : 
     335              :   const OperType &GetOperatorAtLevel(std::size_t l) const
     336              :   {
     337              :     MFEM_ASSERT(l < GetNumLevels(), "Out of bounds multigrid level operator requested!");
     338              :     return *ops[l];
     339              :   }
     340              :   const OperType &GetAuxiliaryOperatorAtLevel(std::size_t l) const
     341              :   {
     342              :     MFEM_ASSERT(l < GetNumAuxiliaryLevels(),
     343              :                 "Out of bounds multigrid level auxiliary operator requested!");
     344              :     return *aux_ops[l];
     345              :   }
     346              : 
     347            0 :   void Mult(const VecType &x, VecType &y) const override { GetFinestOperator().Mult(x, y); }
     348              : 
     349            0 :   void MultTranspose(const VecType &x, VecType &y) const override
     350              :   {
     351            0 :     GetFinestOperator().MultTranspose(x, y);
     352            0 :   }
     353              : 
     354            0 :   void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
     355              :   {
     356            0 :     GetFinestOperator().AddMult(x, y, a);
     357            0 :   }
     358              : 
     359            0 :   void AddMultTranspose(const VecType &x, VecType &y,
     360              :                         const ScalarType a = 1.0) const override
     361              :   {
     362            0 :     GetFinestOperator().AddMultTranspose(x, y, a);
     363            0 :   }
     364              : };
     365              : 
     366              : using MultigridOperator = BaseMultigridOperator<Operator>;
     367              : using ComplexMultigridOperator = BaseMultigridOperator<ComplexOperator>;
     368              : 
     369              : namespace linalg
     370              : {
     371              : 
     372              : // Calculate the vector norm with respect to an SPD matrix B.
     373              : template <typename VecType>
     374              : double Norml2(MPI_Comm comm, const VecType &x, const Operator &B, VecType &Bx);
     375              : 
     376              : // Normalize the vector with respect to an SPD matrix B.
     377              : template <typename VecType>
     378              : inline double Normalize(MPI_Comm comm, VecType &x, const Operator &B, VecType &Bx)
     379              : {
     380              :   double norm = Norml2(comm, x, B, Bx);
     381              :   MFEM_ASSERT(norm > 0.0, "Zero vector norm in normalization!");
     382              :   x *= 1.0 / norm;
     383              :   return norm;
     384              : }
     385              : 
     386              : // Estimate operator 2-norm (spectral norm) using power iteration. Assumes the operator is
     387              : // not symmetric or Hermitian unless specified.
     388              : double SpectralNorm(MPI_Comm comm, const Operator &A, bool sym = false, double tol = 1.0e-4,
     389              :                     int max_it = 1000);
     390              : double SpectralNorm(MPI_Comm comm, const ComplexOperator &A, bool herm = false,
     391              :                     double tol = 1.0e-4, int max_it = 1000);
     392              : 
     393              : }  // namespace linalg
     394              : 
     395              : }  // namespace palace
     396              : 
     397              : #endif  // PALACE_LINALG_OPERATOR_HPP
        

Generated by: LCOV version 2.0-1