LCOV - code coverage report
Current view: top level - linalg - iterative.cpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 0.0 % 386 0
Test Date: 2025-10-23 22:45:05 Functions: 0.0 % 27 0
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #include "iterative.hpp"
       5              : 
       6              : #include <algorithm>
       7              : #include <cmath>
       8              : #include <limits>
       9              : #include <string>
      10              : #include "linalg/orthog.hpp"
      11              : #include "utils/communication.hpp"
      12              : #include "utils/timer.hpp"
      13              : 
      14              : namespace palace
      15              : {
      16              : 
      17              : namespace
      18              : {
      19              : 
      20              : template <typename T>
      21              : inline void CheckDot(T dot, const char *msg)
      22              : {
      23              :   MFEM_ASSERT(std::isfinite(dot) && dot >= 0.0, msg << dot << "!");
      24              : }
      25              : 
      26              : template <typename T>
      27              : inline void CheckDot(std::complex<T> dot, const char *msg)
      28              : {
      29              :   MFEM_ASSERT(std::isfinite(dot.real()) && std::isfinite(dot.imag()) && dot.real() >= 0.0,
      30              :               msg << dot << "!");
      31              : }
      32              : 
      33              : template <typename T>
      34              : inline constexpr T SafeMin()
      35              : {
      36              :   // Originally part of <T>LAPACK.
      37              :   // <T>LAPACK is free software: you can redistribute it and/or modify it under
      38              :   // the terms of the BSD 3-Clause license.
      39              :   //
      40              :   // Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
      41              :   // Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
      42              :   //
      43              :   // Original author: Weslley S Pereira, University of Colorado Denver, USA
      44              :   constexpr int fradix = std::numeric_limits<T>::radix;
      45              :   constexpr int expm = std::numeric_limits<T>::min_exponent;
      46              :   constexpr int expM = std::numeric_limits<T>::max_exponent;
      47              :   // Note: pow is not constexpr in C++17 so this actually might not return a constexpr for
      48              :   // all compilers.
      49              :   return std::max(std::pow(fradix, T(expm - 1)), std::pow(fradix, T(1 - expM)));
      50              : }
      51              : 
      52              : template <typename T>
      53              : inline constexpr T SafeMax()
      54              : {
      55              :   // Originally part of <T>LAPACK.
      56              :   // <T>LAPACK is free software: you can redistribute it and/or modify it under
      57              :   // the terms of the BSD 3-Clause license.
      58              :   //
      59              :   // Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
      60              :   // Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
      61              :   //
      62              :   // Original author: Weslley S Pereira, University of Colorado Denver, USA
      63              :   constexpr int fradix = std::numeric_limits<T>::radix;
      64              :   constexpr int expm = std::numeric_limits<T>::min_exponent;
      65              :   constexpr int expM = std::numeric_limits<T>::max_exponent;
      66              :   // Note: pow is not constexpr in C++17 so this actually might not return a constexpr for
      67              :   // all compilers.
      68              :   return std::min(std::pow(fradix, T(1 - expm)), std::pow(fradix, T(expM - 1)));
      69              : }
      70              : 
      71              : template <typename T>
      72            0 : inline void GeneratePlaneRotation(const T dx, const T dy, T &cs, T &sn)
      73              : {
      74              :   // See LAPACK's s/dlartg.
      75            0 :   const T safmin = SafeMin<T>();
      76            0 :   const T safmax = SafeMax<T>();
      77              :   const T root_min = std::sqrt(safmin);
      78              :   const T root_max = std::sqrt(safmax / 2);
      79            0 :   if (dy == 0.0)
      80              :   {
      81            0 :     cs = 1.0;
      82            0 :     sn = 0.0;
      83            0 :     return;
      84              :   }
      85            0 :   if (dx == 0.0)
      86              :   {
      87            0 :     cs = 0.0;
      88            0 :     sn = std::copysign(1.0, dy);
      89            0 :     return;
      90              :   }
      91            0 :   T dx1 = std::abs(dx);
      92            0 :   T dy1 = std::abs(dy);
      93            0 :   if (dx1 > root_min && dx1 < root_max && dy1 > root_min && dy1 < root_max)
      94              :   {
      95            0 :     T d = std::sqrt(dx * dx + dy * dy);
      96            0 :     cs = dx1 / d;
      97            0 :     sn = dy / std::copysign(d, dx);
      98            0 :   }
      99              :   else
     100              :   {
     101            0 :     T u = std::min(safmax, std::max(safmin, std::max(dx1, dy1)));
     102            0 :     T dxs = dx / u;
     103            0 :     T dys = dy / u;
     104            0 :     T d = std::sqrt(dxs * dxs + dys * dys);
     105            0 :     cs = std::abs(dxs) / d;
     106            0 :     sn = dys / std::copysign(d, dx);
     107              :   }
     108              : }
     109              : 
     110              : template <typename T>
     111            0 : inline void GeneratePlaneRotation(const std::complex<T> dx, const std::complex<T> dy, T &cs,
     112              :                                   std::complex<T> &sn)
     113              : {
     114              :   // Generates a plane rotation so that:
     115              :   //   [  cs        sn ] [ dx ]  =  [ r ]
     116              :   //   [ -conj(sn)  cs ] [ dy ]     [ 0 ]
     117              :   // where cs is real and cs² + |sn|² = 1. See LAPACK's c/zlartg.
     118            0 :   const T safmin = SafeMin<T>();
     119            0 :   const T safmax = SafeMax<T>();
     120              :   if (dy == 0.0)
     121              :   {
     122            0 :     cs = 1.0;
     123              :     sn = 0.0;
     124            0 :     return;
     125              :   }
     126              :   if (dx == 0.0)
     127              :   {
     128            0 :     cs = 0.0;
     129            0 :     if (dy.real() == 0.0)
     130              :     {
     131            0 :       sn = std::conj(dy) / std::abs(dy.imag());
     132              :     }
     133            0 :     else if (dy.imag() == 0.0)
     134              :     {
     135            0 :       sn = std::conj(dy) / std::abs(dy.real());
     136              :     }
     137              :     else
     138              :     {
     139              :       const T root_min = std::sqrt(safmin);
     140              :       const T root_max = std::sqrt(safmax / 2);
     141            0 :       T dy1 = std::max(std::abs(dy.real()), std::abs(dy.imag()));
     142            0 :       if (dy1 > root_min && dy1 < root_max)
     143              :       {
     144            0 :         sn = std::conj(dy) / std::sqrt(dy.real() * dy.real() + dy.imag() * dy.imag());
     145              :       }
     146              :       else
     147              :       {
     148            0 :         T u = std::min(safmax, std::max(safmin, dy1));
     149              :         std::complex<T> dys = dy / u;
     150            0 :         sn = std::conj(dys) / std::sqrt(dys.real() * dys.real() + dys.imag() * dys.imag());
     151              :       }
     152              :     }
     153            0 :     return;
     154              :   }
     155              :   const T root_min = std::sqrt(safmin);
     156              :   const T root_max = std::sqrt(safmax / 4);
     157            0 :   T dx1 = std::max(std::abs(dx.real()), std::abs(dx.imag()));
     158            0 :   T dy1 = std::max(std::abs(dy.real()), std::abs(dy.imag()));
     159            0 :   if (dx1 > root_min && dx1 < root_max && dy1 > root_min && dy1 < root_max)
     160              :   {
     161            0 :     T dx2 = dx.real() * dx.real() + dx.imag() * dx.imag();
     162            0 :     T dy2 = dy.real() * dy.real() + dy.imag() * dy.imag();
     163            0 :     T dz2 = dx2 + dy2;
     164            0 :     if (dx2 >= dz2 * safmin)
     165              :     {
     166            0 :       cs = std::sqrt(dx2 / dz2);
     167            0 :       if (dx2 > root_min && dz2 < root_max * 2)
     168              :       {
     169            0 :         sn = std::conj(dy) * (dx / std::sqrt(dx2 * dz2));
     170              :       }
     171              :       else
     172              :       {
     173            0 :         sn = std::conj(dy) * ((dx / cs) / dz2);
     174              :       }
     175              :     }
     176              :     else
     177              :     {
     178            0 :       T d = std::sqrt(dx2 * dz2);
     179            0 :       cs = dx2 / d;
     180            0 :       sn = std::conj(dy) * (dx / d);
     181              :     }
     182              :   }
     183              :   else
     184              :   {
     185            0 :     T u = std::min(safmax, std::max(safmin, std::max(dx1, dy1))), w;
     186              :     std::complex<T> dys = dy / u, dxs;
     187            0 :     T dy2 = dys.real() * dys.real() + dys.imag() * dys.imag(), dx2, dz2;
     188            0 :     if (dx1 / u < root_min)
     189              :     {
     190            0 :       T v = std::min(safmax, std::max(safmin, dx1));
     191            0 :       w = v / u;
     192              :       dxs = dx / v;
     193            0 :       dx2 = dxs.real() * dxs.real() + dxs.imag() * dxs.imag();
     194            0 :       dz2 = dx2 * w * w + dy2;
     195              :     }
     196              :     else
     197              :     {
     198              :       w = 1.0;
     199              :       dxs = dx / u;
     200            0 :       dx2 = dxs.real() * dxs.real() + dxs.imag() * dxs.imag();
     201            0 :       dz2 = dx2 + dy2;
     202              :     }
     203            0 :     if (dx2 >= dz2 * safmin)
     204              :     {
     205            0 :       cs = std::sqrt(dx2 / dz2);
     206            0 :       if (dx2 > root_min && dz2 < root_max * 2)
     207              :       {
     208            0 :         sn = std::conj(dys) * (dxs / std::sqrt(dx2 * dz2));
     209              :       }
     210              :       else
     211              :       {
     212            0 :         sn = std::conj(dys) * ((dxs / cs) / dz2);
     213              :       }
     214              :     }
     215              :     else
     216              :     {
     217            0 :       T d = std::sqrt(dx2 * dz2);
     218            0 :       cs = dx2 / d;
     219            0 :       sn = std::conj(dys) * (dxs / d);
     220              :     }
     221            0 :     cs *= w;
     222              :   }
     223              : }
     224              : 
     225              : template <typename T>
     226              : inline void ApplyPlaneRotation(T &dx, T &dy, const T cs, const T sn)
     227              : {
     228            0 :   T t = cs * dx + sn * dy;
     229            0 :   dy = -sn * dx + cs * dy;
     230            0 :   dx = t;
     231              : }
     232              : 
     233              : template <typename T>
     234            0 : inline void ApplyPlaneRotation(std::complex<T> &dx, std::complex<T> &dy, const T cs,
     235              :                                const std::complex<T> sn)
     236              : {
     237              :   std::complex<T> t = cs * dx + sn * dy;
     238            0 :   dy = -std::conj(sn) * dx + cs * dy;
     239            0 :   dx = t;
     240            0 : }
     241              : 
     242              : template <typename OperType, typename VecType>
     243            0 : inline void ApplyB(const Solver<OperType> *B, const VecType &x, VecType &y,
     244              :                    bool use_timer = true)
     245              : {
     246            0 :   BlockTimer bt(Timer::KSP_PRECONDITIONER, use_timer);
     247              :   MFEM_ASSERT(B, "Missing preconditioner in ApplyB!");
     248            0 :   B->Mult(x, y);
     249            0 : }
     250              : 
     251              : template <typename OperType, typename VecType>
     252            0 : inline void InitialResidual(PreconditionerSide side, const OperType *A,
     253              :                             const Solver<OperType> *B, const VecType &b, VecType &x,
     254              :                             VecType &r, VecType &z, bool initial_guess,
     255              :                             bool use_timer = true)
     256              : {
     257            0 :   if (B && side == PreconditionerSide::LEFT)
     258              :   {
     259            0 :     if (initial_guess)
     260              :     {
     261            0 :       A->Mult(x, z);
     262            0 :       linalg::AXPBY(1.0, b, -1.0, z);
     263            0 :       ApplyB(B, z, r, use_timer);
     264              :     }
     265              :     else
     266              :     {
     267            0 :       ApplyB(B, b, r, use_timer);
     268            0 :       x = 0.0;
     269              :     }
     270              :   }
     271              :   else  // !B || side == PreconditionerSide::RIGHT
     272              :   {
     273            0 :     if (initial_guess)
     274              :     {
     275            0 :       A->Mult(x, r);
     276            0 :       linalg::AXPBY(1.0, b, -1.0, r);
     277              :     }
     278              :     else
     279              :     {
     280            0 :       r = b;
     281            0 :       x = 0.0;
     282              :     }
     283              :   }
     284            0 : }
     285              : 
     286              : template <typename OperType, typename VecType>
     287            0 : inline void ApplyBA(PreconditionerSide side, const OperType *A, const Solver<OperType> *B,
     288              :                     const VecType &x, VecType &y, VecType &z, bool use_timer = true)
     289              : {
     290            0 :   if (B && side == PreconditionerSide::LEFT)
     291              :   {
     292            0 :     A->Mult(x, z);
     293            0 :     ApplyB(B, z, y, use_timer);
     294              :   }
     295            0 :   else if (B && side == PreconditionerSide::RIGHT)
     296              :   {
     297            0 :     ApplyB(B, x, z, use_timer);
     298            0 :     A->Mult(z, y);
     299              :   }
     300              :   else
     301              :   {
     302            0 :     A->Mult(x, y);
     303              :   }
     304            0 : }
     305              : 
     306              : template <typename VecType, typename ScalarType>
     307            0 : inline void OrthogonalizeIteration(Orthogonalization type, MPI_Comm comm,
     308              :                                    const std::vector<VecType> &V, VecType &w,
     309              :                                    ScalarType *Hj, int j)
     310              : {
     311              :   // Orthogonalize w against the leading j + 1 columns of V.
     312            0 :   switch (type)
     313              :   {
     314            0 :     case Orthogonalization::MGS:
     315            0 :       linalg::OrthogonalizeColumnMGS(comm, V, w, Hj, j + 1);
     316            0 :       break;
     317            0 :     case Orthogonalization::CGS:
     318            0 :       linalg::OrthogonalizeColumnCGS(comm, V, w, Hj, j + 1);
     319            0 :       break;
     320            0 :     case Orthogonalization::CGS2:
     321            0 :       linalg::OrthogonalizeColumnCGS(comm, V, w, Hj, j + 1, true);
     322            0 :       break;
     323              :   }
     324            0 : }
     325              : 
     326              : }  // namespace
     327              : 
     328              : template <typename OperType>
     329            0 : IterativeSolver<OperType>::IterativeSolver(MPI_Comm comm, int print)
     330            0 :   : Solver<OperType>(), comm(comm), A(nullptr), B(nullptr)
     331              : {
     332              :   print_opts.Warnings();
     333            0 :   if (print > 0)
     334              :   {
     335              :     print_opts.Summary();
     336            0 :     if (print > 1)
     337              :     {
     338              :       print_opts.Iterations();
     339            0 :       if (print > 2)
     340              :       {
     341              :         print_opts.All();
     342              :       }
     343              :     }
     344              :   }
     345            0 :   int_width = 3;
     346            0 :   tab_width = 0;
     347              : 
     348            0 :   rel_tol = abs_tol = 0.0;
     349            0 :   max_it = 100;
     350              : 
     351            0 :   converged = false;
     352            0 :   initial_res = 1.0;
     353            0 :   final_res = 0.0;
     354            0 :   final_it = 0;
     355              : 
     356            0 :   use_timer = false;
     357            0 : }
     358              : 
     359              : template <typename OperType>
     360            0 : void CgSolver<OperType>::Mult(const VecType &b, VecType &x) const
     361              : {
     362              :   // Set up workspace.
     363              :   ScalarType beta, beta_prev = 0.0, alpha, denom;
     364              :   RealType res, eps;
     365            0 :   MFEM_VERIFY(A, "Operator must be set for CgSolver::Mult!");
     366              :   MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
     367              :               "Size mismatch for CgSolver::Mult!");
     368            0 :   r.SetSize(A->Height());
     369            0 :   z.SetSize(A->Height());
     370            0 :   p.SetSize(A->Height());
     371            0 :   r.UseDevice(true);
     372            0 :   z.UseDevice(true);
     373            0 :   p.UseDevice(true);
     374              : 
     375              :   // Initialize.
     376            0 :   if (this->initial_guess)
     377              :   {
     378            0 :     A->Mult(x, r);
     379            0 :     linalg::AXPBY(1.0, b, -1.0, r);
     380              :   }
     381              :   else
     382              :   {
     383            0 :     r = b;
     384            0 :     x = 0.0;
     385              :   }
     386            0 :   if (B)
     387              :   {
     388            0 :     ApplyB(B, r, z, this->use_timer);
     389              :   }
     390              :   else
     391              :   {
     392            0 :     z = r;
     393              :   }
     394            0 :   beta = linalg::Dot(comm, z, r);
     395              :   CheckDot(beta, "PCG preconditioner is not positive definite: (Br, r) = ");
     396            0 :   res = std::sqrt(std::abs(beta));
     397            0 :   if (this->initial_guess)
     398              :   {
     399            0 :     ScalarType beta_rhs;
     400            0 :     if (B)
     401              :     {
     402            0 :       ApplyB(B, b, p, this->use_timer);
     403            0 :       beta_rhs = linalg::Dot(comm, p, b);
     404              :     }
     405              :     else
     406              :     {
     407            0 :       beta_rhs = linalg::Norml2(comm, b);
     408              :     }
     409              :     CheckDot(beta_rhs, "PCG preconditioner is not positive definite: (Bb, b) = ");
     410            0 :     initial_res = std::sqrt(std::abs(beta_rhs));
     411              :   }
     412              :   else
     413              :   {
     414            0 :     initial_res = res;
     415              :   }
     416            0 :   eps = std::max(rel_tol * initial_res, abs_tol);
     417            0 :   converged = (res < eps);
     418              : 
     419              :   // Begin iterations.
     420            0 :   int it = 0;
     421            0 :   if (print_opts.iterations)
     422              :   {
     423            0 :     Mpi::Print(comm, "{}Residual norms for PCG solve\n",
     424            0 :                std::string(tab_width + int_width - 1, ' '));
     425              :   }
     426            0 :   for (; it < max_it && !converged; it++)
     427              :   {
     428            0 :     if (print_opts.iterations)
     429              :     {
     430            0 :       Mpi::Print(comm, "{}{:{}d} KSP residual norm ||r||_B = {:.6e}\n",
     431            0 :                  std::string(tab_width, ' '), it, int_width, res);
     432              :     }
     433            0 :     if (!it)
     434              :     {
     435            0 :       p = z;
     436              :     }
     437              :     else
     438              :     {
     439            0 :       linalg::AXPBY(ScalarType(1.0), z, beta / beta_prev, p);
     440              :     }
     441              : 
     442            0 :     A->Mult(p, z);
     443            0 :     denom = linalg::Dot(comm, z, p);
     444              :     CheckDot(denom, "PCG operator is not positive definite: (Ap, p) = ");
     445            0 :     alpha = beta / denom;
     446              : 
     447            0 :     x.Add(alpha, p);
     448            0 :     r.Add(-alpha, z);
     449              : 
     450              :     beta_prev = beta;
     451            0 :     if (B)
     452              :     {
     453            0 :       ApplyB(B, r, z, this->use_timer);
     454              :     }
     455              :     else
     456              :     {
     457            0 :       z = r;
     458              :     }
     459            0 :     beta = linalg::Dot(comm, z, r);
     460              :     CheckDot(beta, "PCG preconditioner is not positive definite: (Br, r) = ");
     461            0 :     res = std::sqrt(std::abs(beta));
     462            0 :     converged = (res < eps);
     463              :   }
     464            0 :   if (print_opts.iterations)
     465              :   {
     466            0 :     Mpi::Print(comm, "{}{:{}d} KSP residual norm ||r||_B = {:.6e}\n",
     467            0 :                std::string(tab_width, ' '), it, int_width, res);
     468              :   }
     469            0 :   if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
     470              :   {
     471            0 :     Mpi::Print(comm, "{}PCG solver {} in {:d} iteration{}", std::string(tab_width, ' '),
     472            0 :                converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
     473            0 :     if (it > 0)
     474              :     {
     475            0 :       Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
     476            0 :                  std::pow(res / initial_res, 1.0 / it));
     477              :     }
     478              :     else
     479              :     {
     480            0 :       Mpi::Print(comm, "\n");
     481              :     }
     482              :   }
     483            0 :   final_res = res;
     484            0 :   final_it = it;
     485            0 : }
     486              : 
     487              : template <typename OperType>
     488            0 : void GmresSolver<OperType>::Initialize() const
     489              : {
     490            0 :   if (!V.empty())
     491              :   {
     492              :     MFEM_ASSERT(V.size() == static_cast<std::size_t>(max_dim + 1) &&
     493              :                     V[0].Size() == A->Height(),
     494              :                 "Repeated solves with GmresSolver should not modify the operator size or "
     495              :                 "restart dimension!");
     496            0 :     return;
     497              :   }
     498            0 :   if (max_dim < 0)
     499              :   {
     500            0 :     max_dim = max_it;
     501              :   }
     502            0 :   constexpr int init_size = 5;
     503            0 :   V.resize(max_dim + 1);
     504            0 :   for (int j = 0; j < std::min(init_size, max_dim + 1); j++)
     505              :   {
     506            0 :     V[j].SetSize(A->Height());
     507            0 :     V[j].UseDevice(true);
     508              :   }
     509            0 :   H.resize((max_dim + 1) * max_dim);
     510            0 :   s.resize(max_dim + 1);
     511            0 :   cs.resize(max_dim + 1);
     512            0 :   sn.resize(max_dim + 1);
     513              : }
     514              : 
     515              : template <typename OperType>
     516            0 : void GmresSolver<OperType>::Update(int j) const
     517              : {
     518              :   // Add storage for basis vectors in increments.
     519              :   constexpr int add_size = 10;
     520            0 :   for (int k = j + 1; k < std::min(j + 1 + add_size, max_dim + 1); k++)
     521              :   {
     522            0 :     V[k].SetSize(A->Height());
     523            0 :     V[k].UseDevice(true);
     524              :   }
     525            0 : }
     526              : 
     527              : template <typename OperType>
     528            0 : void GmresSolver<OperType>::Mult(const VecType &b, VecType &x) const
     529              : {
     530              :   // Set up workspace.
     531            0 :   RealType beta = 0.0, true_beta, eps = 0.0;
     532            0 :   MFEM_VERIFY(A, "Operator must be set for GmresSolver::Mult!");
     533              :   MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
     534              :               "Size mismatch for GmresSolver::Mult!");
     535            0 :   r.SetSize(A->Height());
     536            0 :   r.UseDevice(true);
     537            0 :   Initialize();
     538              : 
     539              :   // Begin iterations.
     540            0 :   converged = false;
     541            0 :   int it = 0, restart = 0;
     542            0 :   if (print_opts.iterations)
     543              :   {
     544            0 :     Mpi::Print(comm, "{}Residual norms for GMRES solve\n",
     545            0 :                std::string(tab_width + int_width - 1, ' '));
     546              :   }
     547            0 :   for (; it < max_it; restart++)
     548              :   {
     549              :     // Initialize.
     550            0 :     InitialResidual(pc_side, A, B, b, x, r, V[0], (this->initial_guess || restart > 0),
     551            0 :                     this->use_timer);
     552            0 :     true_beta = linalg::Norml2(comm, r);
     553              :     CheckDot(true_beta, "GMRES residual norm is not valid: beta = ");
     554            0 :     if (it == 0)
     555              :     {
     556            0 :       if (this->initial_guess)
     557              :       {
     558              :         RealType beta_rhs;
     559            0 :         if (B && pc_side == PreconditionerSide::LEFT)
     560              :         {
     561            0 :           ApplyB(B, b, V[0], this->use_timer);
     562            0 :           beta_rhs = linalg::Norml2(comm, V[0]);
     563              :         }
     564              :         else  // !B || pc_side == PreconditionerSide::RIGHT
     565              :         {
     566            0 :           beta_rhs = linalg::Norml2(comm, b);
     567              :         }
     568              :         CheckDot(beta_rhs, "GMRES residual norm is not valid: beta_rhs = ");
     569            0 :         initial_res = beta_rhs;
     570              :       }
     571              :       else
     572              :       {
     573            0 :         initial_res = true_beta;
     574              :       }
     575            0 :       eps = std::max(rel_tol * initial_res, abs_tol);
     576              :     }
     577            0 :     else if (beta > 0.0 && std::abs(beta - true_beta) > 0.1 * true_beta &&
     578            0 :              print_opts.warnings)
     579              :     {
     580            0 :       Mpi::Print(
     581            0 :           comm,
     582              :           "{}GMRES residual at restart ({:.6e}) is far from the residual norm estimate "
     583              :           "from the recursion formula ({:.6e}) (initial residual = {:.6e})\n",
     584            0 :           std::string(tab_width, ' '), true_beta, beta, initial_res);
     585              :     }
     586            0 :     beta = true_beta;
     587            0 :     if (beta < eps)
     588              :     {
     589            0 :       converged = true;
     590            0 :       break;
     591              :     }
     592              : 
     593            0 :     V[0] = 0.0;
     594            0 :     V[0].Add(1.0 / beta, r);
     595              :     std::fill(s.begin(), s.end(), 0.0);
     596            0 :     s[0] = beta;
     597              : 
     598              :     int j = 0;
     599            0 :     for (;; j++, it++)
     600              :     {
     601            0 :       if (print_opts.iterations)
     602              :       {
     603            0 :         Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
     604            0 :                    std::string(tab_width, ' '), it, int_width, restart, beta);
     605              :       }
     606            0 :       VecType &w = V[j + 1];
     607            0 :       if (w.Size() == 0)
     608              :       {
     609            0 :         Update(j);
     610              :       }
     611            0 :       ApplyBA(pc_side, A, B, V[j], w, r, this->use_timer);
     612              : 
     613            0 :       ScalarType *Hj = H.data() + j * (max_dim + 1);
     614            0 :       OrthogonalizeIteration(gs_orthog, comm, V, w, Hj, j);
     615            0 :       Hj[j + 1] = linalg::Norml2(comm, w);
     616            0 :       w *= 1.0 / Hj[j + 1];
     617              : 
     618            0 :       for (int k = 0; k < j; k++)
     619              :       {
     620            0 :         ApplyPlaneRotation(Hj[k], Hj[k + 1], cs[k], sn[k]);
     621              :       }
     622            0 :       GeneratePlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
     623            0 :       ApplyPlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
     624            0 :       ApplyPlaneRotation(s[j], s[j + 1], cs[j], sn[j]);
     625              : 
     626            0 :       beta = std::abs(s[j + 1]);
     627              :       CheckDot(beta, "GMRES residual norm is not valid: beta = ");
     628            0 :       converged = (beta < eps);
     629            0 :       if (converged || j + 1 == max_dim || it + 1 == max_it)
     630              :       {
     631            0 :         it++;
     632              :         break;
     633              :       }
     634              :     }
     635              : 
     636              :     // Reconstruct the solution (for restart or due to convergence or maximum iterations).
     637            0 :     for (int i = j; i >= 0; i--)
     638              :     {
     639            0 :       ScalarType *Hi = H.data() + i * (max_dim + 1);
     640            0 :       s[i] /= Hi[i];
     641            0 :       for (int k = i - 1; k >= 0; k--)
     642              :       {
     643            0 :         s[k] -= Hi[k] * s[i];
     644              :       }
     645              :     }
     646            0 :     if (!B || pc_side == PreconditionerSide::LEFT)
     647              :     {
     648            0 :       for (int k = 0; k <= j; k++)
     649              :       {
     650            0 :         x.Add(s[k], V[k]);
     651              :       }
     652              :     }
     653              :     else  // B && pc_side == PreconditionerSide::RIGHT
     654              :     {
     655            0 :       r = 0.0;
     656            0 :       for (int k = 0; k <= j; k++)
     657              :       {
     658            0 :         r.Add(s[k], V[k]);
     659              :       }
     660            0 :       ApplyB(B, r, V[0], this->use_timer);
     661            0 :       x += V[0];
     662              :     }
     663            0 :     if (converged)
     664              :     {
     665              :       break;
     666              :     }
     667              :   }
     668            0 :   if (print_opts.iterations)
     669              :   {
     670            0 :     Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
     671            0 :                std::string(tab_width, ' '), it, int_width, restart, beta);
     672              :   }
     673            0 :   if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
     674              :   {
     675            0 :     Mpi::Print(comm, "{}GMRES solver {} in {:d} iteration{}", std::string(tab_width, ' '),
     676            0 :                converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
     677            0 :     if (it > 0)
     678              :     {
     679            0 :       Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
     680            0 :                  std::pow(beta / initial_res, 1.0 / it));
     681              :     }
     682              :     else
     683              :     {
     684            0 :       Mpi::Print(comm, "\n");
     685              :     }
     686              :   }
     687            0 :   final_res = beta;
     688            0 :   final_it = it;
     689            0 : }
     690              : 
     691              : template <typename OperType>
     692            0 : void FgmresSolver<OperType>::Initialize() const
     693              : {
     694            0 :   GmresSolver<OperType>::Initialize();
     695            0 :   constexpr int init_size = 5;
     696            0 :   Z.resize(max_dim + 1);
     697            0 :   for (int j = 0; j < std::min(init_size, max_dim + 1); j++)
     698              :   {
     699            0 :     Z[j].SetSize(A->Height());
     700            0 :     Z[j].UseDevice(true);
     701              :   }
     702            0 : }
     703              : 
     704              : template <typename OperType>
     705            0 : void FgmresSolver<OperType>::Update(int j) const
     706              : {
     707              :   // Add storage for basis vectors in increments.
     708            0 :   GmresSolver<OperType>::Update(j);
     709              :   constexpr int add_size = 10;
     710            0 :   for (int k = j + 1; k < std::min(j + 1 + add_size, max_dim + 1); k++)
     711              :   {
     712            0 :     Z[k].SetSize(A->Height());
     713            0 :     Z[k].UseDevice(true);
     714              :   }
     715            0 : }
     716              : 
     717              : template <typename OperType>
     718            0 : void FgmresSolver<OperType>::Mult(const VecType &b, VecType &x) const
     719              : {
     720              :   // Set up workspace.
     721            0 :   RealType beta = 0.0, true_beta, eps = 0.0;
     722            0 :   MFEM_VERIFY(A && B, "Operator and preconditioner must be set for FgmresSolver::Mult!");
     723              :   MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
     724              :               "Size mismatch for FgmresSolver::Mult!");
     725            0 :   Initialize();
     726              : 
     727              :   // Begin iterations.
     728            0 :   converged = false;
     729            0 :   int it = 0, restart = 0;
     730            0 :   if (print_opts.iterations)
     731              :   {
     732            0 :     Mpi::Print(comm, "{}Residual norms for FGMRES solve\n",
     733            0 :                std::string(tab_width + int_width - 1, ' '));
     734              :   }
     735            0 :   for (; it < max_it; restart++)
     736              :   {
     737              :     // Initialize.
     738            0 :     InitialResidual(PreconditionerSide::RIGHT, A, B, b, x, Z[0], V[0],
     739            0 :                     (this->initial_guess || restart > 0), this->use_timer);
     740            0 :     true_beta = linalg::Norml2(comm, Z[0]);
     741              :     CheckDot(true_beta, "FGMRES residual norm is not valid: beta = ");
     742            0 :     if (it == 0)
     743              :     {
     744            0 :       if (this->initial_guess)
     745              :       {
     746            0 :         auto beta_rhs = linalg::Norml2(comm, b);
     747              :         CheckDot(beta_rhs, "GMRES residual norm is not valid: beta_rhs = ");
     748            0 :         initial_res = beta_rhs;
     749              :       }
     750              :       else
     751              :       {
     752            0 :         initial_res = true_beta;
     753              :       }
     754            0 :       eps = std::max(rel_tol * initial_res, abs_tol);
     755              :     }
     756            0 :     else if (beta > 0.0 && std::abs(beta - true_beta) > 0.1 * true_beta &&
     757            0 :              print_opts.warnings)
     758              :     {
     759            0 :       Mpi::Print(
     760            0 :           comm,
     761              :           "{}FGMRES residual at restart ({:.6e}) is far from the residual norm estimate "
     762              :           "from the recursion formula ({:.6e}) (initial residual = {:.6e})\n",
     763            0 :           std::string(tab_width, ' '), true_beta, beta, initial_res);
     764              :     }
     765            0 :     beta = true_beta;
     766            0 :     if (beta < eps)
     767              :     {
     768            0 :       converged = true;
     769            0 :       break;
     770              :     }
     771              : 
     772            0 :     V[0] = 0.0;
     773            0 :     V[0].Add(1.0 / beta, Z[0]);
     774              :     std::fill(s.begin(), s.end(), 0.0);
     775            0 :     s[0] = beta;
     776              : 
     777              :     int j = 0;
     778            0 :     for (;; j++, it++)
     779              :     {
     780            0 :       if (print_opts.iterations)
     781              :       {
     782            0 :         Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
     783            0 :                    std::string(tab_width, ' '), it, int_width, restart, beta);
     784              :       }
     785            0 :       VecType &w = V[j + 1];
     786            0 :       if (w.Size() == 0)
     787              :       {
     788            0 :         Update(j);
     789              :       }
     790            0 :       ApplyBA(PreconditionerSide::RIGHT, A, B, V[j], w, Z[j], this->use_timer);
     791              : 
     792            0 :       ScalarType *Hj = H.data() + j * (max_dim + 1);
     793            0 :       OrthogonalizeIteration(gs_orthog, comm, V, w, Hj, j);
     794            0 :       Hj[j + 1] = linalg::Norml2(comm, w);
     795            0 :       w *= 1.0 / Hj[j + 1];
     796              : 
     797            0 :       for (int k = 0; k < j; k++)
     798              :       {
     799            0 :         ApplyPlaneRotation(Hj[k], Hj[k + 1], cs[k], sn[k]);
     800              :       }
     801            0 :       GeneratePlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
     802            0 :       ApplyPlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
     803            0 :       ApplyPlaneRotation(s[j], s[j + 1], cs[j], sn[j]);
     804              : 
     805            0 :       beta = std::abs(s[j + 1]);
     806              :       CheckDot(beta, "FGMRES residual norm is not valid: beta = ");
     807            0 :       converged = (beta < eps);
     808            0 :       if (converged || j + 1 == max_dim || it + 1 == max_it)
     809              :       {
     810            0 :         it++;
     811              :         break;
     812              :       }
     813              :     }
     814              : 
     815              :     // Reconstruct the solution (for restart or due to convergence or maximum iterations).
     816            0 :     for (int i = j; i >= 0; i--)
     817              :     {
     818            0 :       ScalarType *Hi = H.data() + i * (max_dim + 1);
     819            0 :       s[i] /= Hi[i];
     820            0 :       for (int k = i - 1; k >= 0; k--)
     821              :       {
     822            0 :         s[k] -= Hi[k] * s[i];
     823              :       }
     824              :     }
     825            0 :     for (int k = 0; k <= j; k++)
     826              :     {
     827            0 :       x.Add(s[k], Z[k]);
     828              :     }
     829            0 :     if (converged)
     830              :     {
     831              :       break;
     832              :     }
     833              :   }
     834            0 :   if (print_opts.iterations)
     835              :   {
     836            0 :     Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
     837            0 :                std::string(tab_width, ' '), it, int_width, restart, beta);
     838              :   }
     839            0 :   if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
     840              :   {
     841            0 :     Mpi::Print(comm, "{}FGMRES solver {} in {:d} iteration{}", std::string(tab_width, ' '),
     842            0 :                converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
     843            0 :     if (it > 0)
     844              :     {
     845            0 :       Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
     846            0 :                  std::pow(beta / initial_res, 1.0 / it));
     847              :     }
     848              :     else
     849              :     {
     850            0 :       Mpi::Print(comm, "\n");
     851              :     }
     852              :   }
     853            0 :   final_res = beta;
     854            0 :   final_it = it;
     855            0 : }
     856              : 
     857              : template class IterativeSolver<Operator>;
     858              : template class IterativeSolver<ComplexOperator>;
     859              : template class CgSolver<Operator>;
     860              : template class CgSolver<ComplexOperator>;
     861              : template class GmresSolver<Operator>;
     862              : template class GmresSolver<ComplexOperator>;
     863              : template class FgmresSolver<Operator>;
     864              : template class FgmresSolver<ComplexOperator>;
     865              : 
     866              : }  // namespace palace
        

Generated by: LCOV version 2.0-1