LCOV - code coverage report
Current view: top level - linalg - iterative.hpp (source / functions)
Test: Palace Coverage Report
Test Date: 2025-10-23 22:45:05
Coverage summary (Coverage / Total / Hit): Lines: 0.0 % / 34 / 0; Functions: 0.0 % / 40 / 0
Legend: Lines: hit | not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_LINALG_ITERATIVE_HPP
       5              : #define PALACE_LINALG_ITERATIVE_HPP
       6              : 
       7              : #include <type_traits>
       8              : #include <vector>
       9              : #include <mfem.hpp>
      10              : #include "linalg/operator.hpp"
      11              : #include "linalg/solver.hpp"
      12              : #include "linalg/vector.hpp"
      13              : #include "utils/labels.hpp"
      14              : 
      15              : namespace palace
      16              : {
      17              : 
      18              : //
      19              : // Iterative solvers based on Krylov subspace methods with optional preconditioning, for
      20              : // real- or complex-valued systems.
      21              : //
      22              : 
      23              : // Base class for iterative solvers based on Krylov subspace methods with optional
      24              : // preconditioning.
      25              : template <typename OperType>
      26              : class IterativeSolver : public Solver<OperType>
      27              : {
      28              : protected:
      29              :   using RealType = double;
      30              :   using ScalarType =
      31              :       typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
      32              :                                 std::complex<RealType>, RealType>::type;
      33              : 
      34              :   // MPI communicator associated with the solver.
      35              :   MPI_Comm comm;
      36              : 
      37              :   // Control level of printing during solves.
      38              :   mfem::IterativeSolver::PrintLevel print_opts;
      39              :   int int_width, tab_width;
      40              : 
      41              :   // Relative and absolute tolerances.
      42              :   double rel_tol, abs_tol;
      43              : 
      44              :   // Limit for the number of solver iterations.
      45              :   int max_it;
      46              : 
      47              :   // Operator and (optional) preconditioner associated with the iterative solver (not
      48              :   // owned).
      49              :   const OperType *A;
      50              :   const Solver<OperType> *B;
      51              : 
      52              :   // Variables set during solve to capture solve statistics.
      53              :   mutable bool converged;
      54              :   mutable double initial_res, final_res;
      55              :   mutable int final_it;
      56              : 
      57              :   // Enable timer contribution for Timer::PRECONDITIONER.
      58              :   bool use_timer;
      59              : 
      60              : public:
      61              :   IterativeSolver(MPI_Comm comm, int print);
      62              : 
      63              :   // Set an indentation for all log printing.
      64            0 :   void SetTabWidth(int width) { tab_width = width; }
      65              : 
      66              :   // Set the relative convergence tolerance.
      67            0 :   void SetTol(double tol) { SetRelTol(tol); }
      68            0 :   void SetRelTol(double tol) { rel_tol = tol; }
      69              : 
      70              :   // Set the absolute convergence tolerance.
      71            0 :   void SetAbsTol(double tol) { abs_tol = tol; }
      72              : 
      73              :   // Set the maximum number of iterations.
      74            0 :   void SetMaxIter(int its)
      75              :   {
      76            0 :     max_it = its;
      77            0 :     int_width = 1 + static_cast<int>(std::log10(its));
      78            0 :   }
      79              : 
      80              :   // Set the operator for the solver.
      81            0 :   void SetOperator(const OperType &op) override
      82              :   {
      83            0 :     A = &op;
      84            0 :     this->height = op.Height();
      85            0 :     this->width = op.Width();
      86            0 :   }
      87              : 
      88              :   // Set the preconditioner for the solver.
      89            0 :   void SetPreconditioner(const Solver<OperType> &pc) { B = &pc; }
      90              : 
      91              :   // Returns if the previous solve converged or not.
      92            0 :   bool GetConverged() const { return converged && (rel_tol > 0.0 || abs_tol > 0.0); }
      93              : 
      94              :   // Returns the initial (absolute) residual for the previous solve.
      95            0 :   double GetInitialRes() const { return initial_res; }
      96              : 
      97              :   // Returns the final (absolute) residual for the previous solve, which may be an estimate
      98              :   // to the true residual.
      99            0 :   double GetFinalRes() const { return final_res; }
     100              : 
     101              :   // Returns the number of iterations for the previous solve.
     102            0 :   int GetNumIterations() const { return final_it; }
     103              : 
     104              :   // Get the associated MPI communicator.
     105            0 :   MPI_Comm GetComm() const { return comm; }
     106              : 
     107              :   // Activate preconditioner timing during solves.
     108            0 :   void EnableTimer() { use_timer = true; }
     109              : };
     110              : 
     111              : // Preconditioned Conjugate Gradient (CG) method for SPD linear systems.
     112              : template <typename OperType>
     113              : class CgSolver : public IterativeSolver<OperType>
     114              : {
     115              : protected:
     116              :   using VecType = typename Solver<OperType>::VecType;
     117              :   using RealType = typename IterativeSolver<OperType>::RealType;
     118              :   using ScalarType = typename IterativeSolver<OperType>::ScalarType;
     119              : 
     120              :   using IterativeSolver<OperType>::comm;
     121              :   using IterativeSolver<OperType>::print_opts;
     122              :   using IterativeSolver<OperType>::int_width;
     123              :   using IterativeSolver<OperType>::tab_width;
     124              : 
     125              :   using IterativeSolver<OperType>::rel_tol;
     126              :   using IterativeSolver<OperType>::abs_tol;
     127              :   using IterativeSolver<OperType>::max_it;
     128              : 
     129              :   using IterativeSolver<OperType>::A;
     130              :   using IterativeSolver<OperType>::B;
     131              : 
     132              :   using IterativeSolver<OperType>::converged;
     133              :   using IterativeSolver<OperType>::initial_res;
     134              :   using IterativeSolver<OperType>::final_res;
     135              :   using IterativeSolver<OperType>::final_it;
     136              : 
     137              :   // Temporary workspace for solve.
     138              :   mutable VecType r, z, p;
     139              : 
     140              : public:
     141            0 :   CgSolver(MPI_Comm comm, int print) : IterativeSolver<OperType>(comm, print) {}
     142              : 
     143              :   void Mult(const VecType &b, VecType &x) const override;
     144              : };
     145              : 
     146              : // Preconditioned Generalized Minimum Residual Method (GMRES) for general nonsymmetric
     147              : // linear systems.
     148              : template <typename OperType>
     149              : class GmresSolver : public IterativeSolver<OperType>
     150              : {
     151              : protected:
     152              :   using VecType = typename Solver<OperType>::VecType;
     153              :   using RealType = typename IterativeSolver<OperType>::RealType;
     154              :   using ScalarType = typename IterativeSolver<OperType>::ScalarType;
     155              : 
     156              :   using IterativeSolver<OperType>::comm;
     157              :   using IterativeSolver<OperType>::print_opts;
     158              :   using IterativeSolver<OperType>::int_width;
     159              :   using IterativeSolver<OperType>::tab_width;
     160              : 
     161              :   using IterativeSolver<OperType>::rel_tol;
     162              :   using IterativeSolver<OperType>::abs_tol;
     163              :   using IterativeSolver<OperType>::max_it;
     164              : 
     165              :   using IterativeSolver<OperType>::A;
     166              :   using IterativeSolver<OperType>::B;
     167              : 
     168              :   using IterativeSolver<OperType>::converged;
     169              :   using IterativeSolver<OperType>::initial_res;
     170              :   using IterativeSolver<OperType>::final_res;
     171              :   using IterativeSolver<OperType>::final_it;
     172              : 
     173              :   // Maximum subspace dimension for restarted GMRES.
     174              :   mutable int max_dim;
     175              : 
     176              :   // Orthogonalization method for orthonormalizing a newly computed vector against a basis
     177              :   // at each iteration.
     178              :   Orthogonalization gs_orthog;
     179              : 
     180              :   // Use left or right preconditioning.
     181              :   PreconditionerSide pc_side;
     182              : 
     183              :   // Temporary workspace for solve.
     184              :   mutable std::vector<VecType> V;
     185              :   mutable VecType r;
     186              :   mutable std::vector<ScalarType> H;
     187              :   mutable std::vector<ScalarType> s, sn;
     188              :   mutable std::vector<RealType> cs;
     189              : 
     190              :   // Allocate storage for solve.
     191              :   virtual void Initialize() const;
     192              :   virtual void Update(int j) const;
     193              : 
     194              : public:
     195            0 :   GmresSolver(MPI_Comm comm, int print)
     196            0 :     : IterativeSolver<OperType>(comm, print), max_dim(-1),
     197            0 :       gs_orthog(Orthogonalization::MGS), pc_side(PreconditionerSide::LEFT)
     198              :   {
     199            0 :   }
     200              : 
     201              :   // Set the dimension for restart.
     202            0 :   void SetRestartDim(int dim) { max_dim = dim; }
     203              : 
     204              :   // Set the orthogonalization method.
     205            0 :   void SetOrthogonalization(Orthogonalization orthog) { gs_orthog = orthog; }
     206              : 
     207              :   // Set the side for preconditioning.
     208            0 :   virtual void SetPreconditionerSide(PreconditionerSide side) { pc_side = side; }
     209              : 
     210              :   void Mult(const VecType &b, VecType &x) const override;
     211              : };
     212              : 
     213              : // Preconditioned Flexible Generalized Minimum Residual Method (FGMRES) for general
     214              : // nonsymmetric linear systems with a non-constant preconditioner.
     215              : template <typename OperType>
     216              : class FgmresSolver : public GmresSolver<OperType>
     217              : {
     218              : protected:
     219              :   using VecType = typename GmresSolver<OperType>::VecType;
     220              :   using RealType = typename GmresSolver<OperType>::RealType;
     221              :   using ScalarType = typename GmresSolver<OperType>::ScalarType;
     222              : 
     223              :   using GmresSolver<OperType>::comm;
     224              :   using GmresSolver<OperType>::print_opts;
     225              :   using GmresSolver<OperType>::int_width;
     226              :   using GmresSolver<OperType>::tab_width;
     227              : 
     228              :   using GmresSolver<OperType>::rel_tol;
     229              :   using GmresSolver<OperType>::abs_tol;
     230              :   using GmresSolver<OperType>::max_it;
     231              : 
     232              :   using GmresSolver<OperType>::A;
     233              :   using GmresSolver<OperType>::B;
     234              : 
     235              :   using GmresSolver<OperType>::converged;
     236              :   using GmresSolver<OperType>::initial_res;
     237              :   using GmresSolver<OperType>::final_res;
     238              :   using GmresSolver<OperType>::final_it;
     239              : 
     240              :   using GmresSolver<OperType>::max_dim;
     241              :   using GmresSolver<OperType>::gs_orthog;
     242              :   using GmresSolver<OperType>::pc_side;
     243              :   using GmresSolver<OperType>::V;
     244              :   using GmresSolver<OperType>::H;
     245              :   using GmresSolver<OperType>::s;
     246              :   using GmresSolver<OperType>::sn;
     247              :   using GmresSolver<OperType>::cs;
     248              : 
     249              :   // Temporary workspace for solve.
     250              :   mutable std::vector<VecType> Z;
     251              : 
     252              :   // Allocate storage for solve.
     253              :   void Initialize() const override;
     254              :   void Update(int j) const override;
     255              : 
     256              : public:
     257            0 :   FgmresSolver(MPI_Comm comm, int print) : GmresSolver<OperType>(comm, print)
     258              :   {
     259            0 :     pc_side = PreconditionerSide::RIGHT;
     260            0 :   }
     261              : 
     262            0 :   void SetPreconditionerSide(const PreconditionerSide side) override
     263              :   {
     264            0 :     MFEM_VERIFY(side == PreconditionerSide::RIGHT,
     265              :                 "FGMRES solver only supports right preconditioning!");
     266            0 :   }
     267              : 
     268              :   void Mult(const VecType &b, VecType &x) const override;
     269              : };
     270              : 
     271              : }  // namespace palace
     272              : 
     273              : #endif  // PALACE_LINALG_ITERATIVE_HPP
        

Generated by: LCOV version 2.0-1