Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_ITERATIVE_HPP
5 : #define PALACE_LINALG_ITERATIVE_HPP
6 :
7 : #include <type_traits>
8 : #include <vector>
9 : #include <mfem.hpp>
10 : #include "linalg/operator.hpp"
11 : #include "linalg/solver.hpp"
12 : #include "linalg/vector.hpp"
13 : #include "utils/labels.hpp"
14 :
15 : namespace palace
16 : {
17 :
18 : //
19 : // Iterative solvers based on Krylov subspace methods with optional preconditioning, for
20 : // real- or complex-valued systems.
21 : //
22 :
23 : // Base class for iterative solvers based on Krylov subspace methods with optional
24 : // preconditioning.
25 : template <typename OperType>
26 : class IterativeSolver : public Solver<OperType>
27 : {
28 : protected:
29 : using RealType = double;
30 : using ScalarType =
31 : typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
32 : std::complex<RealType>, RealType>::type;
33 :
34 : // MPI communicator associated with the solver.
35 : MPI_Comm comm;
36 :
37 : // Control level of printing during solves.
38 : mfem::IterativeSolver::PrintLevel print_opts;
39 : int int_width, tab_width;
40 :
41 : // Relative and absolute tolerances.
42 : double rel_tol, abs_tol;
43 :
44 : // Limit for the number of solver iterations.
45 : int max_it;
46 :
47 : // Operator and (optional) preconditioner associated with the iterative solver (not
48 : // owned).
49 : const OperType *A;
50 : const Solver<OperType> *B;
51 :
52 : // Variables set during solve to capture solve statistics.
53 : mutable bool converged;
54 : mutable double initial_res, final_res;
55 : mutable int final_it;
56 :
57 : // Enable timer contribution for Timer::PRECONDITIONER.
58 : bool use_timer;
59 :
60 : public:
61 : IterativeSolver(MPI_Comm comm, int print);
62 :
63 : // Set an indentation for all log printing.
64 0 : void SetTabWidth(int width) { tab_width = width; }
65 :
66 : // Set the relative convergence tolerance.
67 0 : void SetTol(double tol) { SetRelTol(tol); }
68 0 : void SetRelTol(double tol) { rel_tol = tol; }
69 :
70 : // Set the absolute convergence tolerance.
71 0 : void SetAbsTol(double tol) { abs_tol = tol; }
72 :
73 : // Set the maximum number of iterations.
74 0 : void SetMaxIter(int its)
75 : {
76 0 : max_it = its;
77 0 : int_width = 1 + static_cast<int>(std::log10(its));
78 0 : }
79 :
80 : // Set the operator for the solver.
81 0 : void SetOperator(const OperType &op) override
82 : {
83 0 : A = &op;
84 0 : this->height = op.Height();
85 0 : this->width = op.Width();
86 0 : }
87 :
88 : // Set the preconditioner for the solver.
89 0 : void SetPreconditioner(const Solver<OperType> &pc) { B = &pc; }
90 :
91 : // Returns if the previous solve converged or not.
92 0 : bool GetConverged() const { return converged && (rel_tol > 0.0 || abs_tol > 0.0); }
93 :
94 : // Returns the initial (absolute) residual for the previous solve.
95 0 : double GetInitialRes() const { return initial_res; }
96 :
97 : // Returns the final (absolute) residual for the previous solve, which may be an estimate
98 : // to the true residual.
99 0 : double GetFinalRes() const { return final_res; }
100 :
101 : // Returns the number of iterations for the previous solve.
102 0 : int GetNumIterations() const { return final_it; }
103 :
104 : // Get the associated MPI communicator.
105 0 : MPI_Comm GetComm() const { return comm; }
106 :
107 : // Activate preconditioner timing during solves.
108 0 : void EnableTimer() { use_timer = true; }
109 : };
110 :
111 : // Preconditioned Conjugate Gradient (CG) method for SPD linear systems.
112 : template <typename OperType>
113 : class CgSolver : public IterativeSolver<OperType>
114 : {
115 : protected:
116 : using VecType = typename Solver<OperType>::VecType;
117 : using RealType = typename IterativeSolver<OperType>::RealType;
118 : using ScalarType = typename IterativeSolver<OperType>::ScalarType;
119 :
120 : using IterativeSolver<OperType>::comm;
121 : using IterativeSolver<OperType>::print_opts;
122 : using IterativeSolver<OperType>::int_width;
123 : using IterativeSolver<OperType>::tab_width;
124 :
125 : using IterativeSolver<OperType>::rel_tol;
126 : using IterativeSolver<OperType>::abs_tol;
127 : using IterativeSolver<OperType>::max_it;
128 :
129 : using IterativeSolver<OperType>::A;
130 : using IterativeSolver<OperType>::B;
131 :
132 : using IterativeSolver<OperType>::converged;
133 : using IterativeSolver<OperType>::initial_res;
134 : using IterativeSolver<OperType>::final_res;
135 : using IterativeSolver<OperType>::final_it;
136 :
137 : // Temporary workspace for solve.
138 : mutable VecType r, z, p;
139 :
140 : public:
141 0 : CgSolver(MPI_Comm comm, int print) : IterativeSolver<OperType>(comm, print) {}
142 :
143 : void Mult(const VecType &b, VecType &x) const override;
144 : };
145 :
146 : // Preconditioned Generalized Minimum Residual Method (GMRES) for general nonsymmetric
147 : // linear systems.
148 : template <typename OperType>
149 : class GmresSolver : public IterativeSolver<OperType>
150 : {
151 : protected:
152 : using VecType = typename Solver<OperType>::VecType;
153 : using RealType = typename IterativeSolver<OperType>::RealType;
154 : using ScalarType = typename IterativeSolver<OperType>::ScalarType;
155 :
156 : using IterativeSolver<OperType>::comm;
157 : using IterativeSolver<OperType>::print_opts;
158 : using IterativeSolver<OperType>::int_width;
159 : using IterativeSolver<OperType>::tab_width;
160 :
161 : using IterativeSolver<OperType>::rel_tol;
162 : using IterativeSolver<OperType>::abs_tol;
163 : using IterativeSolver<OperType>::max_it;
164 :
165 : using IterativeSolver<OperType>::A;
166 : using IterativeSolver<OperType>::B;
167 :
168 : using IterativeSolver<OperType>::converged;
169 : using IterativeSolver<OperType>::initial_res;
170 : using IterativeSolver<OperType>::final_res;
171 : using IterativeSolver<OperType>::final_it;
172 :
173 : // Maximum subspace dimension for restarted GMRES.
174 : mutable int max_dim;
175 :
176 : // Orthogonalization method for orthonormalizing a newly computed vector against a basis
177 : // at each iteration.
178 : Orthogonalization gs_orthog;
179 :
180 : // Use left or right preconditioning.
181 : PreconditionerSide pc_side;
182 :
183 : // Temporary workspace for solve.
184 : mutable std::vector<VecType> V;
185 : mutable VecType r;
186 : mutable std::vector<ScalarType> H;
187 : mutable std::vector<ScalarType> s, sn;
188 : mutable std::vector<RealType> cs;
189 :
190 : // Allocate storage for solve.
191 : virtual void Initialize() const;
192 : virtual void Update(int j) const;
193 :
194 : public:
195 0 : GmresSolver(MPI_Comm comm, int print)
196 0 : : IterativeSolver<OperType>(comm, print), max_dim(-1),
197 0 : gs_orthog(Orthogonalization::MGS), pc_side(PreconditionerSide::LEFT)
198 : {
199 0 : }
200 :
201 : // Set the dimension for restart.
202 0 : void SetRestartDim(int dim) { max_dim = dim; }
203 :
204 : // Set the orthogonalization method.
205 0 : void SetOrthogonalization(Orthogonalization orthog) { gs_orthog = orthog; }
206 :
207 : // Set the side for preconditioning.
208 0 : virtual void SetPreconditionerSide(PreconditionerSide side) { pc_side = side; }
209 :
210 : void Mult(const VecType &b, VecType &x) const override;
211 : };
212 :
213 : // Preconditioned Flexible Generalized Minimum Residual Method (FGMRES) for general
214 : // nonsymmetric linear systems with a non-constant preconditioner.
215 : template <typename OperType>
216 : class FgmresSolver : public GmresSolver<OperType>
217 : {
218 : protected:
219 : using VecType = typename GmresSolver<OperType>::VecType;
220 : using RealType = typename GmresSolver<OperType>::RealType;
221 : using ScalarType = typename GmresSolver<OperType>::ScalarType;
222 :
223 : using GmresSolver<OperType>::comm;
224 : using GmresSolver<OperType>::print_opts;
225 : using GmresSolver<OperType>::int_width;
226 : using GmresSolver<OperType>::tab_width;
227 :
228 : using GmresSolver<OperType>::rel_tol;
229 : using GmresSolver<OperType>::abs_tol;
230 : using GmresSolver<OperType>::max_it;
231 :
232 : using GmresSolver<OperType>::A;
233 : using GmresSolver<OperType>::B;
234 :
235 : using GmresSolver<OperType>::converged;
236 : using GmresSolver<OperType>::initial_res;
237 : using GmresSolver<OperType>::final_res;
238 : using GmresSolver<OperType>::final_it;
239 :
240 : using GmresSolver<OperType>::max_dim;
241 : using GmresSolver<OperType>::gs_orthog;
242 : using GmresSolver<OperType>::pc_side;
243 : using GmresSolver<OperType>::V;
244 : using GmresSolver<OperType>::H;
245 : using GmresSolver<OperType>::s;
246 : using GmresSolver<OperType>::sn;
247 : using GmresSolver<OperType>::cs;
248 :
249 : // Temporary workspace for solve.
250 : mutable std::vector<VecType> Z;
251 :
252 : // Allocate storage for solve.
253 : void Initialize() const override;
254 : void Update(int j) const override;
255 :
256 : public:
257 0 : FgmresSolver(MPI_Comm comm, int print) : GmresSolver<OperType>(comm, print)
258 : {
259 0 : pc_side = PreconditionerSide::RIGHT;
260 0 : }
261 :
262 0 : void SetPreconditionerSide(const PreconditionerSide side) override
263 : {
264 0 : MFEM_VERIFY(side == PreconditionerSide::RIGHT,
265 : "FGMRES solver only supports right preconditioning!");
266 0 : }
267 :
268 : void Mult(const VecType &b, VecType &x) const override;
269 : };
270 :
271 : } // namespace palace
272 :
273 : #endif // PALACE_LINALG_ITERATIVE_HPP
|