LCOV - code coverage report
Current view: top level - linalg - ksp.cpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 0.0 % 97 0
Test Date: 2025-10-23 22:45:05 Functions: 0.0 % 22 0
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #include "ksp.hpp"
       5              : 
       6              : #include <mfem.hpp>
       7              : #include "fem/fespace.hpp"
       8              : #include "linalg/amg.hpp"
       9              : #include "linalg/ams.hpp"
      10              : #include "linalg/gmg.hpp"
      11              : #include "linalg/jacobi.hpp"
      12              : #include "linalg/mumps.hpp"
      13              : #include "linalg/strumpack.hpp"
      14              : #include "linalg/superlu.hpp"
      15              : #include "utils/communication.hpp"
      16              : #include "utils/iodata.hpp"
      17              : #include "utils/timer.hpp"
      18              : 
      19              : namespace palace
      20              : {
      21              : 
      22              : namespace
      23              : {
      24              : 
      25              : template <typename OperType>
      26            0 : std::unique_ptr<IterativeSolver<OperType>> ConfigureKrylovSolver(const IoData &iodata,
      27              :                                                                  MPI_Comm comm)
      28              : {
      29              :   // Create the solver.
      30            0 :   std::unique_ptr<IterativeSolver<OperType>> ksp;
      31            0 :   const auto type = iodata.solver.linear.krylov_solver;
      32            0 :   const int print = iodata.problem.verbose;
      33            0 :   switch (type)
      34              :   {
      35            0 :     case KrylovSolver::CG:
      36            0 :       ksp = std::make_unique<CgSolver<OperType>>(comm, print);
      37            0 :       break;
      38            0 :     case KrylovSolver::GMRES:
      39              :       {
      40            0 :         auto gmres = std::make_unique<GmresSolver<OperType>>(comm, print);
      41            0 :         gmres->SetRestartDim(iodata.solver.linear.max_size);
      42              :         ksp = std::move(gmres);
      43              :       }
      44            0 :       break;
      45            0 :     case KrylovSolver::FGMRES:
      46              :       {
      47            0 :         auto fgmres = std::make_unique<FgmresSolver<OperType>>(comm, print);
      48            0 :         fgmres->SetRestartDim(iodata.solver.linear.max_size);
      49              :         ksp = std::move(fgmres);
      50              :       }
      51            0 :       break;
      52            0 :     case KrylovSolver::MINRES:
      53              :     case KrylovSolver::BICGSTAB:
      54              :     case KrylovSolver::DEFAULT:
      55            0 :       MFEM_ABORT("Unexpected solver type for Krylov solver configuration!");
      56              :       break;
      57              :   }
      58            0 :   ksp->SetInitialGuess(iodata.solver.linear.initial_guess);
      59            0 :   ksp->SetRelTol(iodata.solver.linear.tol);
      60            0 :   ksp->SetMaxIter(iodata.solver.linear.max_it);
      61              : 
      62              :   // Configure preconditioning side (only for GMRES).
      63            0 :   if (iodata.solver.linear.pc_side != PreconditionerSide::DEFAULT &&
      64              :       type != KrylovSolver::GMRES)
      65              :   {
      66            0 :     Mpi::Warning(comm,
      67              :                  "Preconditioner side will be ignored for non-GMRES iterative solvers!\n");
      68              :   }
      69              :   else
      70              :   {
      71            0 :     if (type == KrylovSolver::GMRES || type == KrylovSolver::FGMRES)
      72              :     {
      73              :       auto *gmres = static_cast<GmresSolver<OperType> *>(ksp.get());
      74            0 :       switch (iodata.solver.linear.pc_side)
      75              :       {
      76            0 :         case PreconditionerSide::LEFT:
      77            0 :           gmres->SetPreconditionerSide(PreconditionerSide::LEFT);
      78              :           break;
      79            0 :         case PreconditionerSide::RIGHT:
      80            0 :           gmres->SetPreconditionerSide(PreconditionerSide::RIGHT);
      81              :           break;
      82              :         case PreconditionerSide::DEFAULT:
      83              :           // Do nothing. Set in ctors.
      84              :           break;
      85              :       }
      86              :     }
      87              :   }
      88              : 
      89              :   // Configure orthogonalization method for GMRES/FMGRES.
      90            0 :   if (type == KrylovSolver::GMRES || type == KrylovSolver::FGMRES)
      91              :   {
      92              :     // Because FGMRES inherits from GMRES, this is OK.
      93              :     auto *gmres = static_cast<GmresSolver<OperType> *>(ksp.get());
      94            0 :     gmres->SetOrthogonalization(iodata.solver.linear.gs_orthog);
      95              :   }
      96              : 
      97              :   // Configure timing for the primary linear solver.
      98              :   ksp->EnableTimer();
      99              : 
     100            0 :   return ksp;
     101              : }
     102              : 
     103              : template <typename OperType, typename T, typename... U>
     104            0 : auto MakeWrapperSolver(const IoData &iodata, U &&...args)
     105              : {
     106              :   // Sparse direct solver types copy the input matrix, so there is no need to save the
     107              :   // parallel assembled operator.
     108            0 :   constexpr bool save_assembled = !(false ||
     109              : #if defined(MFEM_USE_SUPERLU)
     110              :                                     std::is_same<T, SuperLUSolver>::value ||
     111              : #endif
     112              : #if defined(MFEM_USE_STRUMPACK)
     113              :                                     std::is_same<T, StrumpackSolver>::value ||
     114              :                                     std::is_same<T, StrumpackMixedPrecisionSolver>::value ||
     115              : #endif
     116              : #if defined(MFEM_USE_MUMPS)
     117              :                                     std::is_same<T, MumpsSolver>::value ||
     118              : #endif
     119              :                                     false);
     120              :   return std::make_unique<MfemWrapperSolver<OperType>>(
     121            0 :       std::make_unique<T>(iodata, std::forward<U>(args)...), save_assembled,
     122            0 :       iodata.solver.linear.complex_coarse_solve);
     123              : }
     124              : 
     125              : template <typename OperType>
     126              : std::unique_ptr<Solver<OperType>>
     127            0 : ConfigurePreconditionerSolver(const IoData &iodata, MPI_Comm comm,
     128              :                               FiniteElementSpaceHierarchy &fespaces,
     129              :                               FiniteElementSpaceHierarchy *aux_fespaces)
     130              : {
     131              :   // Create the real-valued solver first.
     132            0 :   std::unique_ptr<Solver<OperType>> pc;
     133            0 :   const auto type = iodata.solver.linear.type;
     134            0 :   const int print = iodata.problem.verbose - 1;
     135            0 :   switch (type)
     136              :   {
     137            0 :     case LinearSolver::AMS:
     138              :       // Can either be the coarse solve for geometric multigrid or the solver at the finest
     139              :       // space (in which case fespaces.GetNumLevels() == 1).
     140            0 :       MFEM_VERIFY(aux_fespaces, "AMS solver relies on both primary space "
     141              :                                 "and auxiliary spaces for construction!");
     142            0 :       pc = MakeWrapperSolver<OperType, HypreAmsSolver>(
     143            0 :           iodata, fespaces.GetNumLevels() > 1, fespaces.GetFESpaceAtLevel(0),
     144              :           aux_fespaces->GetFESpaceAtLevel(0), print);
     145            0 :       break;
     146              :     case LinearSolver::BOOMER_AMG:
     147            0 :       pc = MakeWrapperSolver<OperType, BoomerAmgSolver>(iodata, fespaces.GetNumLevels() > 1,
     148              :                                                         print);
     149            0 :       break;
     150            0 :     case LinearSolver::SUPERLU:
     151              : #if defined(MFEM_USE_SUPERLU)
     152              :       pc = MakeWrapperSolver<OperType, SuperLUSolver>(iodata, comm, print);
     153              : #else
     154            0 :       MFEM_ABORT("Solver was not built with SuperLU_DIST support, please choose a "
     155              :                  "different solver!");
     156              : #endif
     157              :       break;
     158            0 :     case LinearSolver::STRUMPACK:
     159              : #if defined(MFEM_USE_STRUMPACK)
     160            0 :       pc = MakeWrapperSolver<OperType, StrumpackSolver>(iodata, comm, print);
     161              : #else
     162              :       MFEM_ABORT("Solver was not built with STRUMPACK support, please choose a "
     163              :                  "different solver!");
     164              : #endif
     165            0 :       break;
     166            0 :     case LinearSolver::STRUMPACK_MP:
     167              : #if defined(MFEM_USE_STRUMPACK)
     168            0 :       pc = MakeWrapperSolver<OperType, StrumpackMixedPrecisionSolver>(iodata, comm, print);
     169              : #else
     170              :       MFEM_ABORT("Solver was not built with STRUMPACK support, please choose a "
     171              :                  "different solver!");
     172              : #endif
     173            0 :       break;
     174            0 :     case LinearSolver::MUMPS:
     175              : #if defined(MFEM_USE_MUMPS)
     176              :       pc = MakeWrapperSolver<OperType, MumpsSolver>(iodata, comm, print);
     177              : #else
     178            0 :       MFEM_ABORT(
     179              :           "Solver was not built with MUMPS support, please choose a different solver!");
     180              : #endif
     181              :       break;
     182            0 :     case LinearSolver::JACOBI:
     183            0 :       pc = std::make_unique<JacobiSmoother<OperType>>(comm);
     184            0 :       break;
     185            0 :     case LinearSolver::DEFAULT:
     186            0 :       MFEM_ABORT("Unexpected solver type for preconditioner configuration!");
     187              :       break;
     188              :   }
     189              : 
     190              :   // Construct the actual solver, which has the right value type.
     191            0 :   if (fespaces.GetNumLevels() > 1)
     192              :   {
     193              :     // This will construct the multigrid hierarchy using pc as the coarse solver
     194              :     // (ownership of pc is transferred to the GeometricMultigridSolver). When a special
     195              :     // auxiliary space smoother for pre-/post-smoothing is not desired, the auxiliary
     196              :     // space is a nullptr here.
     197            0 :     auto gmg = [&]()
     198              :     {
     199            0 :       if (iodata.solver.linear.mg_smooth_aux)
     200              :       {
     201            0 :         MFEM_VERIFY(aux_fespaces, "Multigrid with auxiliary space smoothers requires both "
     202              :                                   "primary space and auxiliary spaces for construction!");
     203            0 :         const auto G = fespaces.GetDiscreteInterpolators(*aux_fespaces);
     204              :         return std::make_unique<GeometricMultigridSolver<OperType>>(
     205            0 :             iodata, comm, std::move(pc), fespaces.GetProlongationOperators(), &G);
     206              :       }
     207              :       else
     208              :       {
     209              :         return std::make_unique<GeometricMultigridSolver<OperType>>(
     210            0 :             iodata, comm, std::move(pc), fespaces.GetProlongationOperators());
     211              :       }
     212              :     }();
     213              :     gmg->EnableTimer();  // Enable timing for primary geometric multigrid solver
     214              :     return gmg;
     215              :   }
     216              :   else
     217              :   {
     218              :     return pc;
     219              :   }
     220              : }
     221              : 
     222              : }  // namespace
     223              : 
     224              : template <typename OperType>
     225            0 : BaseKspSolver<OperType>::BaseKspSolver(const IoData &iodata,
     226              :                                        FiniteElementSpaceHierarchy &fespaces,
     227              :                                        FiniteElementSpaceHierarchy *aux_fespaces)
     228              :   : BaseKspSolver(
     229            0 :         ConfigureKrylovSolver<OperType>(iodata, fespaces.GetFinestFESpace().GetComm()),
     230            0 :         ConfigurePreconditionerSolver<OperType>(
     231              :             iodata, fespaces.GetFinestFESpace().GetComm(), fespaces, aux_fespaces))
     232              : {
     233            0 :   use_timer = true;
     234            0 : }
     235              : 
     236              : template <typename OperType>
     237            0 : BaseKspSolver<OperType>::BaseKspSolver(std::unique_ptr<IterativeSolver<OperType>> &&ksp,
     238              :                                        std::unique_ptr<Solver<OperType>> &&pc)
     239            0 :   : ksp(std::move(ksp)), pc(std::move(pc)), ksp_mult(0), ksp_mult_it(0), use_timer(false)
     240              : {
     241            0 :   if (this->pc)
     242              :   {
     243              :     this->ksp->SetPreconditioner(*this->pc);
     244              :   }
     245            0 : }
     246              : 
     247              : template <typename OperType>
     248            0 : void BaseKspSolver<OperType>::SetOperators(const OperType &op, const OperType &pc_op)
     249              : {
     250            0 :   BlockTimer bt(Timer::KSP_SETUP, use_timer);
     251            0 :   ksp->SetOperator(op);
     252            0 :   if (pc)
     253              :   {
     254            0 :     const auto *mg_op = dynamic_cast<const BaseMultigridOperator<OperType> *>(&pc_op);
     255            0 :     const auto *mg_pc = dynamic_cast<const GeometricMultigridSolver<OperType> *>(pc.get());
     256            0 :     if (mg_op && !mg_pc)
     257              :     {
     258            0 :       pc->SetOperator(mg_op->GetFinestOperator());
     259              :     }
     260              :     else
     261              :     {
     262            0 :       pc->SetOperator(pc_op);
     263              :     }
     264              :   }
     265            0 : }
     266              : 
     267              : template <typename OperType>
     268            0 : void BaseKspSolver<OperType>::Mult(const VecType &x, VecType &y) const
     269              : {
     270            0 :   BlockTimer bt(Timer::KSP, use_timer);
     271            0 :   ksp->Mult(x, y);
     272              :   if (!ksp->GetConverged())
     273              :   {
     274            0 :     Mpi::Warning(
     275              :         ksp->GetComm(),
     276              :         "Linear solver did not converge, norm(Ax-b)/norm(b) = {:.3e} (norm(b) = {:.3e})!\n",
     277            0 :         ksp->GetFinalRes() / ksp->GetInitialRes(), ksp->GetInitialRes());
     278              :   }
     279            0 :   ksp_mult++;
     280            0 :   ksp_mult_it += ksp->GetNumIterations();
     281            0 : }
     282              : 
     283              : template class BaseKspSolver<Operator>;
     284              : template class BaseKspSolver<ComplexOperator>;
     285              : 
     286              : }  // namespace palace
        

Generated by: LCOV version 2.0-1