LCOV - code coverage report
Current view: top level - linalg - rap.hpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 76.9 % 13 10
Test Date: 2025-10-23 22:45:05 Functions: 0.0 % 2 0
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_LINALG_RAP_HPP
       5              : #define PALACE_LINALG_RAP_HPP
       6              : 
       7              : #include <array>
       8              : #include <memory>
       9              : #include <mfem.hpp>
      10              : #include "fem/fespace.hpp"
      11              : #include "linalg/operator.hpp"
      12              : #include "linalg/vector.hpp"
      13              : 
      14              : namespace palace
      15              : {
      16              : 
      17              : //
      18              : // A parallel operator represented by RAP constructed through the actions of R, A, and P,
      19              : // usually with R = Pᵀ, and with possible eliminated essential BC. Here R and P are the
      20              : // parallel restriction and prolongation matrices.
      21              : //
      22              : 
      23              : // Real-valued RAP operator.
      24              : class ParOperator : public Operator
      25              : {
      26              : private:
      27              :   // Storage and access for the local operator.
      28              :   std::unique_ptr<Operator> data_A;
      29              :   const Operator *A;
      30              : 
      31              :   // Finite element spaces for parallel prolongation and restriction.
      32              :   const FiniteElementSpace &trial_fespace, &test_fespace;
      33              :   const bool use_R;
      34              : 
      35              :   // Lists of constrained essential boundary true dofs for elimination.
      36              :   mfem::Array<int> dbc_tdof_list;
      37              : 
      38              :   // Diagonal policy for constrained true dofs.
      39              :   DiagonalPolicy diag_policy = DiagonalPolicy::DIAG_ZERO;
      40              : 
      41              :   // Assembled operator as a parallel Hypre matrix. If assembled, the local operator is not
      42              :   // deleted.
      43              :   mutable std::unique_ptr<mfem::HypreParMatrix> RAP;
      44              : 
      45              :   // Helper methods for operator application.
      46              :   void RestrictionMatrixMult(const Vector &ly, Vector &ty) const;
      47              :   void RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const;
      48              :   Vector &GetTestLVector() const;
      49              : 
      50              :   ParOperator(std::unique_ptr<Operator> &&dA, const Operator *pA,
      51              :               const FiniteElementSpace &trial_fespace,
      52              :               const FiniteElementSpace &test_fespace, bool test_restrict);
      53              : 
      54              : public:
      55              :   // Construct the parallel operator, inheriting ownership of the local operator.
      56              :   ParOperator(std::unique_ptr<Operator> &&A, const FiniteElementSpace &trial_fespace,
      57              :               const FiniteElementSpace &test_fespace, bool test_restrict);
      58              :   ParOperator(std::unique_ptr<Operator> &&A, const FiniteElementSpace &fespace)
      59            9 :     : ParOperator(std::move(A), fespace, fespace, false)
      60              :   {
      61            9 :   }
      62              : 
      63              :   // Non-owning constructors.
      64              :   ParOperator(const Operator &A, const FiniteElementSpace &trial_fespace,
      65              :               const FiniteElementSpace &test_fespace, bool test_restrict);
      66              :   ParOperator(const Operator &A, const FiniteElementSpace &fespace)
      67              :     : ParOperator(A, fespace, fespace, false)
      68              :   {
      69              :   }
      70              : 
      71              :   // Get access to the underlying local (L-vector) operator.
      72           12 :   const Operator &LocalOperator() const { return *A; }
      73              : 
      74              :   // Get the associated MPI communicator.
      75              :   MPI_Comm GetComm() const { return trial_fespace.GetComm(); }
      76              : 
      77              :   // Accessor for trial finite element space.
      78            9 :   const FiniteElementSpace &TrialFiniteElementSpace() const { return trial_fespace; }
      79              : 
      80              :   // Accessor for test finite element space.
      81              :   const FiniteElementSpace &TestFiniteElementSpace() const { return test_fespace; }
      82              : 
      83              :   // Set essential boundary condition true dofs for square operators.
      84              :   void SetEssentialTrueDofs(const mfem::Array<int> &tdof_list, DiagonalPolicy policy);
      85              : 
      86              :   // Get the essential boundary condition true dofs associated with the operator. May be
      87              :   // nullptr.
      88              :   const mfem::Array<int> *GetEssentialTrueDofs() const
      89              :   {
      90            6 :     return dbc_tdof_list.Size() ? &dbc_tdof_list : nullptr;
      91              :   }
      92              : 
      93              :   // Get the diagonal policy that was most recently used. If there are no essential dofs,
      94              :   // and thus no valid policy, will error.
      95              :   DiagonalPolicy GetDiagonalPolicy() const;
      96              : 
      97              :   // Eliminate essential true dofs from the RHS vector b, using the essential boundary
      98              :   // condition values in x.
      99              :   void EliminateRHS(const Vector &x, Vector &b) const;
     100              : 
     101              :   // Assemble the operator as a parallel sparse matrix. The memory associated with the
     102              : // local operator is freed.
     103              :   mfem::HypreParMatrix &ParallelAssemble(bool skip_zeros = false) const;
     104              : 
     105              :   // Steal the assembled parallel sparse matrix.
     106              :   std::unique_ptr<mfem::HypreParMatrix> StealParallelAssemble(bool skip_zeros = false) const
     107              :   {
     108            0 :     ParallelAssemble(skip_zeros);
     109              :     return std::move(RAP);
     110              :   }
     111              : 
     112              :   void AssembleDiagonal(Vector &diag) const override;
     113              : 
     114              :   void Mult(const Vector &x, Vector &y) const override;
     115              : 
     116              :   void MultTranspose(const Vector &x, Vector &y) const override;
     117              : 
     118              :   void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
     119              : 
     120              :   void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
     121              : };
     122              : 
     123              : // Complex-valued RAP operator.
     124              : class ComplexParOperator : public ComplexOperator
     125              : {
     126              : private:
     127              :   // Storage and access for the local operator.
     128              :   std::unique_ptr<ComplexWrapperOperator> data_A;
     129              :   const ComplexWrapperOperator *A;
     130              : 
     131              :   // Finite element spaces for parallel prolongation and restriction.
     132              :   const FiniteElementSpace &trial_fespace, &test_fespace;
     133              :   const bool use_R;
     134              : 
     135              :   // Lists of constrained essential boundary true dofs for elimination.
     136              :   mfem::Array<int> dbc_tdof_list;
     137              : 
     138              :   // Diagonal policy for constrained true dofs.
     139              :   Operator::DiagonalPolicy diag_policy = Operator::DiagonalPolicy::DIAG_ZERO;
     140              : 
     141              :   // Real and imaginary parts of the operator as non-owning ParOperator objects.
     142              :   std::unique_ptr<ParOperator> RAPr, RAPi;
     143              : 
     144              :   // Helper methods for operator application.
     145              :   void RestrictionMatrixMult(const ComplexVector &ly, ComplexVector &ty) const;
     146              :   void RestrictionMatrixMultTranspose(const ComplexVector &ty, ComplexVector &ly) const;
     147              :   ComplexVector &GetTestLVector() const;
     148              : 
     149              :   ComplexParOperator(std::unique_ptr<Operator> &&dAr, std::unique_ptr<Operator> &&dAi,
     150              :                      const Operator *pAr, const Operator *pAi,
     151              :                      const FiniteElementSpace &trial_fespace,
     152              :                      const FiniteElementSpace &test_fespace, bool test_restrict);
     153              : 
     154              : public:
     155              :   // Construct the complex-valued parallel operator from the separate real and imaginary
     156              :   // parts, inheriting ownership of the local operator.
     157              :   ComplexParOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai,
     158              :                      const FiniteElementSpace &trial_fespace,
     159              :                      const FiniteElementSpace &test_fespace, bool test_restrict);
     160              :   ComplexParOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai,
     161              :                      const FiniteElementSpace &fespace)
     162            9 :     : ComplexParOperator(std::move(Ar), std::move(Ai), fespace, fespace, false)
     163              :   {
     164            9 :   }
     165              : 
     166              :   // Non-owning constructors.
     167              :   ComplexParOperator(const Operator *Ar, const Operator *Ai,
     168              :                      const FiniteElementSpace &trial_fespace,
     169              :                      const FiniteElementSpace &test_fespace, bool test_restrict);
     170              :   ComplexParOperator(const Operator *Ar, const Operator *Ai,
     171              :                      const FiniteElementSpace &fespace)
     172              :     : ComplexParOperator(Ar, Ai, fespace, fespace, false)
     173              :   {
     174              :   }
     175              : 
     176            0 :   const Operator *Real() const override { return RAPr.get(); }
     177            0 :   const Operator *Imag() const override { return RAPi.get(); }
     178              : 
     179              :   // Get access to the underlying local (L-vector) operator.
     180           51 :   const ComplexOperator &LocalOperator() const { return *A; }
     181              : 
     182              :   // Get the associated MPI communicator.
     183              :   MPI_Comm GetComm() const { return trial_fespace.GetComm(); }
     184              : 
     185              :   // Accessor for trial finite element space.
     186            9 :   const FiniteElementSpace &TrialFiniteElementSpace() const { return trial_fespace; }
     187              : 
     188              :   // Accessor for test finite element space.
     189              :   const FiniteElementSpace &TestFiniteElementSpace() const { return test_fespace; }
     190              : 
     191              :   // Set essential boundary condition true dofs for square operators.
     192              :   void SetEssentialTrueDofs(const mfem::Array<int> &tdof_list,
     193              :                             Operator::DiagonalPolicy policy);
     194              : 
     195              :   // Get the essential boundary condition true dofs associated with the operator. May be
     196              :   // nullptr.
     197              :   const mfem::Array<int> *GetEssentialTrueDofs() const
     198              :   {
     199            6 :     return dbc_tdof_list.Size() ? &dbc_tdof_list : nullptr;
     200              :   }
     201              : 
     202              :   // Get the diagonal policy that was most recently used. If there are no essential dofs,
     203              :   // and thus no valid policy, will error.
     204              :   Operator::DiagonalPolicy GetDiagonalPolicy() const;
     205              : 
     206              :   void AssembleDiagonal(ComplexVector &diag) const override;
     207              : 
     208              :   void Mult(const ComplexVector &x, ComplexVector &y) const override;
     209              : 
     210              :   void MultTranspose(const ComplexVector &x, ComplexVector &y) const override;
     211              : 
     212              :   void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
     213              : 
     214              :   void AddMult(const ComplexVector &x, ComplexVector &y,
     215              :                const std::complex<double> a = 1.0) const override;
     216              : 
     217              :   void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
     218              :                         const std::complex<double> a = 1.0) const override;
     219              : 
     220              :   void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
     221              :                                  const std::complex<double> a = 1.0) const override;
     222              : };
     223              : 
     224              : // Combine a collection of ParOperator into a weighted summation. If set_essential is true,
     225              : // extract the essential dofs from the operator array, and apply to the summed operator.
     226              : // Requires explicit instantiation.
     227              : template <std::size_t N>
     228              : std::unique_ptr<ParOperator>
     229              : BuildParSumOperator(const std::array<double, N> &coeff,
     230              :                     const std::array<const ParOperator *, N> &ops,
     231              :                     bool set_essential = true);
     232              : 
     233              : // Combine a collection of ComplexParOperator into a weighted summation. If set_essential is
     234              : // true, extract the essential dofs from the operator array, and apply to the summed
     235              : // operator. Requires explicit instantiation.
     236              : template <std::size_t N>
     237              : std::unique_ptr<ComplexParOperator>
     238              : BuildParSumOperator(const std::array<std::complex<double>, N> &coeff,
     239              :                     const std::array<const ComplexParOperator *, N> &ops,
     240              :                     bool set_essential = true);
     241              : 
     242              : // Dispatcher to convert initializer list or C arrays into std::array whilst deducing sizes
     243              : // and types.
     244              : template <std::size_t N, typename ScalarType, typename OperType>
     245              : std::unique_ptr<std::conditional_t<std::is_base_of_v<ComplexOperator, OperType>,
     246              :                                    ComplexParOperator, ParOperator>>
     247              : BuildParSumOperator(ScalarType (&&coeff_in)[N], const OperType *(&&ops_in)[N],
     248              :                     bool set_essential = true);
     249              : 
     250              : }  // namespace palace
     251              : 
     252              : #endif  // PALACE_LINALG_RAP_HPP
        

Generated by: LCOV version 2.0-1