Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "iterative.hpp"
5 :
6 : #include <algorithm>
7 : #include <cmath>
8 : #include <limits>
9 : #include <string>
10 : #include "linalg/orthog.hpp"
11 : #include "utils/communication.hpp"
12 : #include "utils/timer.hpp"
13 :
14 : namespace palace
15 : {
16 :
17 : namespace
18 : {
19 :
20 : template <typename T>
21 : inline void CheckDot(T dot, const char *msg)
22 : {
23 : MFEM_ASSERT(std::isfinite(dot) && dot >= 0.0, msg << dot << "!");
24 : }
25 :
26 : template <typename T>
27 : inline void CheckDot(std::complex<T> dot, const char *msg)
28 : {
29 : MFEM_ASSERT(std::isfinite(dot.real()) && std::isfinite(dot.imag()) && dot.real() >= 0.0,
30 : msg << dot << "!");
31 : }
32 :
33 : template <typename T>
34 : inline constexpr T SafeMin()
35 : {
36 : // Originally part of <T>LAPACK.
37 : // <T>LAPACK is free software: you can redistribute it and/or modify it under
38 : // the terms of the BSD 3-Clause license.
39 : //
40 : // Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
41 : // Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
42 : //
43 : // Original author: Weslley S Pereira, University of Colorado Denver, USA
44 : constexpr int fradix = std::numeric_limits<T>::radix;
45 : constexpr int expm = std::numeric_limits<T>::min_exponent;
46 : constexpr int expM = std::numeric_limits<T>::max_exponent;
47 : // Note: pow is not constexpr in C++17 so this actually might not return a constexpr for
48 : // all compilers.
49 : return std::max(std::pow(fradix, T(expm - 1)), std::pow(fradix, T(1 - expM)));
50 : }
51 :
52 : template <typename T>
53 : inline constexpr T SafeMax()
54 : {
55 : // Originally part of <T>LAPACK.
56 : // <T>LAPACK is free software: you can redistribute it and/or modify it under
57 : // the terms of the BSD 3-Clause license.
58 : //
59 : // Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
60 : // Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
61 : //
62 : // Original author: Weslley S Pereira, University of Colorado Denver, USA
63 : constexpr int fradix = std::numeric_limits<T>::radix;
64 : constexpr int expm = std::numeric_limits<T>::min_exponent;
65 : constexpr int expM = std::numeric_limits<T>::max_exponent;
66 : // Note: pow is not constexpr in C++17 so this actually might not return a constexpr for
67 : // all compilers.
68 : return std::min(std::pow(fradix, T(1 - expm)), std::pow(fradix, T(expM - 1)));
69 : }
70 :
71 : template <typename T>
72 0 : inline void GeneratePlaneRotation(const T dx, const T dy, T &cs, T &sn)
73 : {
74 : // See LAPACK's s/dlartg.
75 0 : const T safmin = SafeMin<T>();
76 0 : const T safmax = SafeMax<T>();
77 : const T root_min = std::sqrt(safmin);
78 : const T root_max = std::sqrt(safmax / 2);
79 0 : if (dy == 0.0)
80 : {
81 0 : cs = 1.0;
82 0 : sn = 0.0;
83 0 : return;
84 : }
85 0 : if (dx == 0.0)
86 : {
87 0 : cs = 0.0;
88 0 : sn = std::copysign(1.0, dy);
89 0 : return;
90 : }
91 0 : T dx1 = std::abs(dx);
92 0 : T dy1 = std::abs(dy);
93 0 : if (dx1 > root_min && dx1 < root_max && dy1 > root_min && dy1 < root_max)
94 : {
95 0 : T d = std::sqrt(dx * dx + dy * dy);
96 0 : cs = dx1 / d;
97 0 : sn = dy / std::copysign(d, dx);
98 0 : }
99 : else
100 : {
101 0 : T u = std::min(safmax, std::max(safmin, std::max(dx1, dy1)));
102 0 : T dxs = dx / u;
103 0 : T dys = dy / u;
104 0 : T d = std::sqrt(dxs * dxs + dys * dys);
105 0 : cs = std::abs(dxs) / d;
106 0 : sn = dys / std::copysign(d, dx);
107 : }
108 : }
109 :
110 : template <typename T>
111 0 : inline void GeneratePlaneRotation(const std::complex<T> dx, const std::complex<T> dy, T &cs,
112 : std::complex<T> &sn)
113 : {
114 : // Generates a plane rotation so that:
115 : // [ cs sn ] [ dx ] = [ r ]
116 : // [ -conj(sn) cs ] [ dy ] [ 0 ]
117 : // where cs is real and cs² + |sn|² = 1. See LAPACK's c/zlartg.
118 0 : const T safmin = SafeMin<T>();
119 0 : const T safmax = SafeMax<T>();
120 : if (dy == 0.0)
121 : {
122 0 : cs = 1.0;
123 : sn = 0.0;
124 0 : return;
125 : }
126 : if (dx == 0.0)
127 : {
128 0 : cs = 0.0;
129 0 : if (dy.real() == 0.0)
130 : {
131 0 : sn = std::conj(dy) / std::abs(dy.imag());
132 : }
133 0 : else if (dy.imag() == 0.0)
134 : {
135 0 : sn = std::conj(dy) / std::abs(dy.real());
136 : }
137 : else
138 : {
139 : const T root_min = std::sqrt(safmin);
140 : const T root_max = std::sqrt(safmax / 2);
141 0 : T dy1 = std::max(std::abs(dy.real()), std::abs(dy.imag()));
142 0 : if (dy1 > root_min && dy1 < root_max)
143 : {
144 0 : sn = std::conj(dy) / std::sqrt(dy.real() * dy.real() + dy.imag() * dy.imag());
145 : }
146 : else
147 : {
148 0 : T u = std::min(safmax, std::max(safmin, dy1));
149 : std::complex<T> dys = dy / u;
150 0 : sn = std::conj(dys) / std::sqrt(dys.real() * dys.real() + dys.imag() * dys.imag());
151 : }
152 : }
153 0 : return;
154 : }
155 : const T root_min = std::sqrt(safmin);
156 : const T root_max = std::sqrt(safmax / 4);
157 0 : T dx1 = std::max(std::abs(dx.real()), std::abs(dx.imag()));
158 0 : T dy1 = std::max(std::abs(dy.real()), std::abs(dy.imag()));
159 0 : if (dx1 > root_min && dx1 < root_max && dy1 > root_min && dy1 < root_max)
160 : {
161 0 : T dx2 = dx.real() * dx.real() + dx.imag() * dx.imag();
162 0 : T dy2 = dy.real() * dy.real() + dy.imag() * dy.imag();
163 0 : T dz2 = dx2 + dy2;
164 0 : if (dx2 >= dz2 * safmin)
165 : {
166 0 : cs = std::sqrt(dx2 / dz2);
167 0 : if (dx2 > root_min && dz2 < root_max * 2)
168 : {
169 0 : sn = std::conj(dy) * (dx / std::sqrt(dx2 * dz2));
170 : }
171 : else
172 : {
173 0 : sn = std::conj(dy) * ((dx / cs) / dz2);
174 : }
175 : }
176 : else
177 : {
178 0 : T d = std::sqrt(dx2 * dz2);
179 0 : cs = dx2 / d;
180 0 : sn = std::conj(dy) * (dx / d);
181 : }
182 : }
183 : else
184 : {
185 0 : T u = std::min(safmax, std::max(safmin, std::max(dx1, dy1))), w;
186 : std::complex<T> dys = dy / u, dxs;
187 0 : T dy2 = dys.real() * dys.real() + dys.imag() * dys.imag(), dx2, dz2;
188 0 : if (dx1 / u < root_min)
189 : {
190 0 : T v = std::min(safmax, std::max(safmin, dx1));
191 0 : w = v / u;
192 : dxs = dx / v;
193 0 : dx2 = dxs.real() * dxs.real() + dxs.imag() * dxs.imag();
194 0 : dz2 = dx2 * w * w + dy2;
195 : }
196 : else
197 : {
198 : w = 1.0;
199 : dxs = dx / u;
200 0 : dx2 = dxs.real() * dxs.real() + dxs.imag() * dxs.imag();
201 0 : dz2 = dx2 + dy2;
202 : }
203 0 : if (dx2 >= dz2 * safmin)
204 : {
205 0 : cs = std::sqrt(dx2 / dz2);
206 0 : if (dx2 > root_min && dz2 < root_max * 2)
207 : {
208 0 : sn = std::conj(dys) * (dxs / std::sqrt(dx2 * dz2));
209 : }
210 : else
211 : {
212 0 : sn = std::conj(dys) * ((dxs / cs) / dz2);
213 : }
214 : }
215 : else
216 : {
217 0 : T d = std::sqrt(dx2 * dz2);
218 0 : cs = dx2 / d;
219 0 : sn = std::conj(dys) * (dxs / d);
220 : }
221 0 : cs *= w;
222 : }
223 : }
224 :
225 : template <typename T>
226 : inline void ApplyPlaneRotation(T &dx, T &dy, const T cs, const T sn)
227 : {
228 0 : T t = cs * dx + sn * dy;
229 0 : dy = -sn * dx + cs * dy;
230 0 : dx = t;
231 : }
232 :
233 : template <typename T>
234 0 : inline void ApplyPlaneRotation(std::complex<T> &dx, std::complex<T> &dy, const T cs,
235 : const std::complex<T> sn)
236 : {
237 : std::complex<T> t = cs * dx + sn * dy;
238 0 : dy = -std::conj(sn) * dx + cs * dy;
239 0 : dx = t;
240 0 : }
241 :
242 : template <typename OperType, typename VecType>
243 0 : inline void ApplyB(const Solver<OperType> *B, const VecType &x, VecType &y,
244 : bool use_timer = true)
245 : {
246 0 : BlockTimer bt(Timer::KSP_PRECONDITIONER, use_timer);
247 : MFEM_ASSERT(B, "Missing preconditioner in ApplyB!");
248 0 : B->Mult(x, y);
249 0 : }
250 :
251 : template <typename OperType, typename VecType>
252 0 : inline void InitialResidual(PreconditionerSide side, const OperType *A,
253 : const Solver<OperType> *B, const VecType &b, VecType &x,
254 : VecType &r, VecType &z, bool initial_guess,
255 : bool use_timer = true)
256 : {
257 0 : if (B && side == PreconditionerSide::LEFT)
258 : {
259 0 : if (initial_guess)
260 : {
261 0 : A->Mult(x, z);
262 0 : linalg::AXPBY(1.0, b, -1.0, z);
263 0 : ApplyB(B, z, r, use_timer);
264 : }
265 : else
266 : {
267 0 : ApplyB(B, b, r, use_timer);
268 0 : x = 0.0;
269 : }
270 : }
271 : else // !B || side == PreconditionerSide::RIGHT
272 : {
273 0 : if (initial_guess)
274 : {
275 0 : A->Mult(x, r);
276 0 : linalg::AXPBY(1.0, b, -1.0, r);
277 : }
278 : else
279 : {
280 0 : r = b;
281 0 : x = 0.0;
282 : }
283 : }
284 0 : }
285 :
286 : template <typename OperType, typename VecType>
287 0 : inline void ApplyBA(PreconditionerSide side, const OperType *A, const Solver<OperType> *B,
288 : const VecType &x, VecType &y, VecType &z, bool use_timer = true)
289 : {
290 0 : if (B && side == PreconditionerSide::LEFT)
291 : {
292 0 : A->Mult(x, z);
293 0 : ApplyB(B, z, y, use_timer);
294 : }
295 0 : else if (B && side == PreconditionerSide::RIGHT)
296 : {
297 0 : ApplyB(B, x, z, use_timer);
298 0 : A->Mult(z, y);
299 : }
300 : else
301 : {
302 0 : A->Mult(x, y);
303 : }
304 0 : }
305 :
306 : template <typename VecType, typename ScalarType>
307 0 : inline void OrthogonalizeIteration(Orthogonalization type, MPI_Comm comm,
308 : const std::vector<VecType> &V, VecType &w,
309 : ScalarType *Hj, int j)
310 : {
311 : // Orthogonalize w against the leading j + 1 columns of V.
312 0 : switch (type)
313 : {
314 0 : case Orthogonalization::MGS:
315 0 : linalg::OrthogonalizeColumnMGS(comm, V, w, Hj, j + 1);
316 0 : break;
317 0 : case Orthogonalization::CGS:
318 0 : linalg::OrthogonalizeColumnCGS(comm, V, w, Hj, j + 1);
319 0 : break;
320 0 : case Orthogonalization::CGS2:
321 0 : linalg::OrthogonalizeColumnCGS(comm, V, w, Hj, j + 1, true);
322 0 : break;
323 : }
324 0 : }
325 :
326 : } // namespace
327 :
328 : template <typename OperType>
329 0 : IterativeSolver<OperType>::IterativeSolver(MPI_Comm comm, int print)
330 0 : : Solver<OperType>(), comm(comm), A(nullptr), B(nullptr)
331 : {
332 : print_opts.Warnings();
333 0 : if (print > 0)
334 : {
335 : print_opts.Summary();
336 0 : if (print > 1)
337 : {
338 : print_opts.Iterations();
339 0 : if (print > 2)
340 : {
341 : print_opts.All();
342 : }
343 : }
344 : }
345 0 : int_width = 3;
346 0 : tab_width = 0;
347 :
348 0 : rel_tol = abs_tol = 0.0;
349 0 : max_it = 100;
350 :
351 0 : converged = false;
352 0 : initial_res = 1.0;
353 0 : final_res = 0.0;
354 0 : final_it = 0;
355 :
356 0 : use_timer = false;
357 0 : }
358 :
359 : template <typename OperType>
360 0 : void CgSolver<OperType>::Mult(const VecType &b, VecType &x) const
361 : {
362 : // Set up workspace.
363 : ScalarType beta, beta_prev = 0.0, alpha, denom;
364 : RealType res, eps;
365 0 : MFEM_VERIFY(A, "Operator must be set for CgSolver::Mult!");
366 : MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
367 : "Size mismatch for CgSolver::Mult!");
368 0 : r.SetSize(A->Height());
369 0 : z.SetSize(A->Height());
370 0 : p.SetSize(A->Height());
371 0 : r.UseDevice(true);
372 0 : z.UseDevice(true);
373 0 : p.UseDevice(true);
374 :
375 : // Initialize.
376 0 : if (this->initial_guess)
377 : {
378 0 : A->Mult(x, r);
379 0 : linalg::AXPBY(1.0, b, -1.0, r);
380 : }
381 : else
382 : {
383 0 : r = b;
384 0 : x = 0.0;
385 : }
386 0 : if (B)
387 : {
388 0 : ApplyB(B, r, z, this->use_timer);
389 : }
390 : else
391 : {
392 0 : z = r;
393 : }
394 0 : beta = linalg::Dot(comm, z, r);
395 : CheckDot(beta, "PCG preconditioner is not positive definite: (Br, r) = ");
396 0 : res = std::sqrt(std::abs(beta));
397 0 : if (this->initial_guess)
398 : {
399 0 : ScalarType beta_rhs;
400 0 : if (B)
401 : {
402 0 : ApplyB(B, b, p, this->use_timer);
403 0 : beta_rhs = linalg::Dot(comm, p, b);
404 : }
405 : else
406 : {
407 0 : beta_rhs = linalg::Norml2(comm, b);
408 : }
409 : CheckDot(beta_rhs, "PCG preconditioner is not positive definite: (Bb, b) = ");
410 0 : initial_res = std::sqrt(std::abs(beta_rhs));
411 : }
412 : else
413 : {
414 0 : initial_res = res;
415 : }
416 0 : eps = std::max(rel_tol * initial_res, abs_tol);
417 0 : converged = (res < eps);
418 :
419 : // Begin iterations.
420 0 : int it = 0;
421 0 : if (print_opts.iterations)
422 : {
423 0 : Mpi::Print(comm, "{}Residual norms for PCG solve\n",
424 0 : std::string(tab_width + int_width - 1, ' '));
425 : }
426 0 : for (; it < max_it && !converged; it++)
427 : {
428 0 : if (print_opts.iterations)
429 : {
430 0 : Mpi::Print(comm, "{}{:{}d} KSP residual norm ||r||_B = {:.6e}\n",
431 0 : std::string(tab_width, ' '), it, int_width, res);
432 : }
433 0 : if (!it)
434 : {
435 0 : p = z;
436 : }
437 : else
438 : {
439 0 : linalg::AXPBY(ScalarType(1.0), z, beta / beta_prev, p);
440 : }
441 :
442 0 : A->Mult(p, z);
443 0 : denom = linalg::Dot(comm, z, p);
444 : CheckDot(denom, "PCG operator is not positive definite: (Ap, p) = ");
445 0 : alpha = beta / denom;
446 :
447 0 : x.Add(alpha, p);
448 0 : r.Add(-alpha, z);
449 :
450 : beta_prev = beta;
451 0 : if (B)
452 : {
453 0 : ApplyB(B, r, z, this->use_timer);
454 : }
455 : else
456 : {
457 0 : z = r;
458 : }
459 0 : beta = linalg::Dot(comm, z, r);
460 : CheckDot(beta, "PCG preconditioner is not positive definite: (Br, r) = ");
461 0 : res = std::sqrt(std::abs(beta));
462 0 : converged = (res < eps);
463 : }
464 0 : if (print_opts.iterations)
465 : {
466 0 : Mpi::Print(comm, "{}{:{}d} KSP residual norm ||r||_B = {:.6e}\n",
467 0 : std::string(tab_width, ' '), it, int_width, res);
468 : }
469 0 : if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
470 : {
471 0 : Mpi::Print(comm, "{}PCG solver {} in {:d} iteration{}", std::string(tab_width, ' '),
472 0 : converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
473 0 : if (it > 0)
474 : {
475 0 : Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
476 0 : std::pow(res / initial_res, 1.0 / it));
477 : }
478 : else
479 : {
480 0 : Mpi::Print(comm, "\n");
481 : }
482 : }
483 0 : final_res = res;
484 0 : final_it = it;
485 0 : }
486 :
487 : template <typename OperType>
488 0 : void GmresSolver<OperType>::Initialize() const
489 : {
490 0 : if (!V.empty())
491 : {
492 : MFEM_ASSERT(V.size() == static_cast<std::size_t>(max_dim + 1) &&
493 : V[0].Size() == A->Height(),
494 : "Repeated solves with GmresSolver should not modify the operator size or "
495 : "restart dimension!");
496 0 : return;
497 : }
498 0 : if (max_dim < 0)
499 : {
500 0 : max_dim = max_it;
501 : }
502 0 : constexpr int init_size = 5;
503 0 : V.resize(max_dim + 1);
504 0 : for (int j = 0; j < std::min(init_size, max_dim + 1); j++)
505 : {
506 0 : V[j].SetSize(A->Height());
507 0 : V[j].UseDevice(true);
508 : }
509 0 : H.resize((max_dim + 1) * max_dim);
510 0 : s.resize(max_dim + 1);
511 0 : cs.resize(max_dim + 1);
512 0 : sn.resize(max_dim + 1);
513 : }
514 :
515 : template <typename OperType>
516 0 : void GmresSolver<OperType>::Update(int j) const
517 : {
518 : // Add storage for basis vectors in increments.
519 : constexpr int add_size = 10;
520 0 : for (int k = j + 1; k < std::min(j + 1 + add_size, max_dim + 1); k++)
521 : {
522 0 : V[k].SetSize(A->Height());
523 0 : V[k].UseDevice(true);
524 : }
525 0 : }
526 :
527 : template <typename OperType>
528 0 : void GmresSolver<OperType>::Mult(const VecType &b, VecType &x) const
529 : {
530 : // Set up workspace.
531 0 : RealType beta = 0.0, true_beta, eps = 0.0;
532 0 : MFEM_VERIFY(A, "Operator must be set for GmresSolver::Mult!");
533 : MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
534 : "Size mismatch for GmresSolver::Mult!");
535 0 : r.SetSize(A->Height());
536 0 : r.UseDevice(true);
537 0 : Initialize();
538 :
539 : // Begin iterations.
540 0 : converged = false;
541 0 : int it = 0, restart = 0;
542 0 : if (print_opts.iterations)
543 : {
544 0 : Mpi::Print(comm, "{}Residual norms for GMRES solve\n",
545 0 : std::string(tab_width + int_width - 1, ' '));
546 : }
547 0 : for (; it < max_it; restart++)
548 : {
549 : // Initialize.
550 0 : InitialResidual(pc_side, A, B, b, x, r, V[0], (this->initial_guess || restart > 0),
551 0 : this->use_timer);
552 0 : true_beta = linalg::Norml2(comm, r);
553 : CheckDot(true_beta, "GMRES residual norm is not valid: beta = ");
554 0 : if (it == 0)
555 : {
556 0 : if (this->initial_guess)
557 : {
558 : RealType beta_rhs;
559 0 : if (B && pc_side == PreconditionerSide::LEFT)
560 : {
561 0 : ApplyB(B, b, V[0], this->use_timer);
562 0 : beta_rhs = linalg::Norml2(comm, V[0]);
563 : }
564 : else // !B || pc_side == PreconditionerSide::RIGHT
565 : {
566 0 : beta_rhs = linalg::Norml2(comm, b);
567 : }
568 : CheckDot(beta_rhs, "GMRES residual norm is not valid: beta_rhs = ");
569 0 : initial_res = beta_rhs;
570 : }
571 : else
572 : {
573 0 : initial_res = true_beta;
574 : }
575 0 : eps = std::max(rel_tol * initial_res, abs_tol);
576 : }
577 0 : else if (beta > 0.0 && std::abs(beta - true_beta) > 0.1 * true_beta &&
578 0 : print_opts.warnings)
579 : {
580 0 : Mpi::Print(
581 0 : comm,
582 : "{}GMRES residual at restart ({:.6e}) is far from the residual norm estimate "
583 : "from the recursion formula ({:.6e}) (initial residual = {:.6e})\n",
584 0 : std::string(tab_width, ' '), true_beta, beta, initial_res);
585 : }
586 0 : beta = true_beta;
587 0 : if (beta < eps)
588 : {
589 0 : converged = true;
590 0 : break;
591 : }
592 :
593 0 : V[0] = 0.0;
594 0 : V[0].Add(1.0 / beta, r);
595 : std::fill(s.begin(), s.end(), 0.0);
596 0 : s[0] = beta;
597 :
598 : int j = 0;
599 0 : for (;; j++, it++)
600 : {
601 0 : if (print_opts.iterations)
602 : {
603 0 : Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
604 0 : std::string(tab_width, ' '), it, int_width, restart, beta);
605 : }
606 0 : VecType &w = V[j + 1];
607 0 : if (w.Size() == 0)
608 : {
609 0 : Update(j);
610 : }
611 0 : ApplyBA(pc_side, A, B, V[j], w, r, this->use_timer);
612 :
613 0 : ScalarType *Hj = H.data() + j * (max_dim + 1);
614 0 : OrthogonalizeIteration(gs_orthog, comm, V, w, Hj, j);
615 0 : Hj[j + 1] = linalg::Norml2(comm, w);
616 0 : w *= 1.0 / Hj[j + 1];
617 :
618 0 : for (int k = 0; k < j; k++)
619 : {
620 0 : ApplyPlaneRotation(Hj[k], Hj[k + 1], cs[k], sn[k]);
621 : }
622 0 : GeneratePlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
623 0 : ApplyPlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
624 0 : ApplyPlaneRotation(s[j], s[j + 1], cs[j], sn[j]);
625 :
626 0 : beta = std::abs(s[j + 1]);
627 : CheckDot(beta, "GMRES residual norm is not valid: beta = ");
628 0 : converged = (beta < eps);
629 0 : if (converged || j + 1 == max_dim || it + 1 == max_it)
630 : {
631 0 : it++;
632 : break;
633 : }
634 : }
635 :
636 : // Reconstruct the solution (for restart or due to convergence or maximum iterations).
637 0 : for (int i = j; i >= 0; i--)
638 : {
639 0 : ScalarType *Hi = H.data() + i * (max_dim + 1);
640 0 : s[i] /= Hi[i];
641 0 : for (int k = i - 1; k >= 0; k--)
642 : {
643 0 : s[k] -= Hi[k] * s[i];
644 : }
645 : }
646 0 : if (!B || pc_side == PreconditionerSide::LEFT)
647 : {
648 0 : for (int k = 0; k <= j; k++)
649 : {
650 0 : x.Add(s[k], V[k]);
651 : }
652 : }
653 : else // B && pc_side == PreconditionerSide::RIGHT
654 : {
655 0 : r = 0.0;
656 0 : for (int k = 0; k <= j; k++)
657 : {
658 0 : r.Add(s[k], V[k]);
659 : }
660 0 : ApplyB(B, r, V[0], this->use_timer);
661 0 : x += V[0];
662 : }
663 0 : if (converged)
664 : {
665 : break;
666 : }
667 : }
668 0 : if (print_opts.iterations)
669 : {
670 0 : Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
671 0 : std::string(tab_width, ' '), it, int_width, restart, beta);
672 : }
673 0 : if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
674 : {
675 0 : Mpi::Print(comm, "{}GMRES solver {} in {:d} iteration{}", std::string(tab_width, ' '),
676 0 : converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
677 0 : if (it > 0)
678 : {
679 0 : Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
680 0 : std::pow(beta / initial_res, 1.0 / it));
681 : }
682 : else
683 : {
684 0 : Mpi::Print(comm, "\n");
685 : }
686 : }
687 0 : final_res = beta;
688 0 : final_it = it;
689 0 : }
690 :
691 : template <typename OperType>
692 0 : void FgmresSolver<OperType>::Initialize() const
693 : {
694 0 : GmresSolver<OperType>::Initialize();
695 0 : constexpr int init_size = 5;
696 0 : Z.resize(max_dim + 1);
697 0 : for (int j = 0; j < std::min(init_size, max_dim + 1); j++)
698 : {
699 0 : Z[j].SetSize(A->Height());
700 0 : Z[j].UseDevice(true);
701 : }
702 0 : }
703 :
704 : template <typename OperType>
705 0 : void FgmresSolver<OperType>::Update(int j) const
706 : {
707 : // Add storage for basis vectors in increments.
708 0 : GmresSolver<OperType>::Update(j);
709 : constexpr int add_size = 10;
710 0 : for (int k = j + 1; k < std::min(j + 1 + add_size, max_dim + 1); k++)
711 : {
712 0 : Z[k].SetSize(A->Height());
713 0 : Z[k].UseDevice(true);
714 : }
715 0 : }
716 :
717 : template <typename OperType>
718 0 : void FgmresSolver<OperType>::Mult(const VecType &b, VecType &x) const
719 : {
720 : // Set up workspace.
721 0 : RealType beta = 0.0, true_beta, eps = 0.0;
722 0 : MFEM_VERIFY(A && B, "Operator and preconditioner must be set for FgmresSolver::Mult!");
723 : MFEM_ASSERT(A->Width() == x.Size() && A->Height() == b.Size(),
724 : "Size mismatch for FgmresSolver::Mult!");
725 0 : Initialize();
726 :
727 : // Begin iterations.
728 0 : converged = false;
729 0 : int it = 0, restart = 0;
730 0 : if (print_opts.iterations)
731 : {
732 0 : Mpi::Print(comm, "{}Residual norms for FGMRES solve\n",
733 0 : std::string(tab_width + int_width - 1, ' '));
734 : }
735 0 : for (; it < max_it; restart++)
736 : {
737 : // Initialize.
738 0 : InitialResidual(PreconditionerSide::RIGHT, A, B, b, x, Z[0], V[0],
739 0 : (this->initial_guess || restart > 0), this->use_timer);
740 0 : true_beta = linalg::Norml2(comm, Z[0]);
741 : CheckDot(true_beta, "FGMRES residual norm is not valid: beta = ");
742 0 : if (it == 0)
743 : {
744 0 : if (this->initial_guess)
745 : {
746 0 : auto beta_rhs = linalg::Norml2(comm, b);
747 : CheckDot(beta_rhs, "GMRES residual norm is not valid: beta_rhs = ");
748 0 : initial_res = beta_rhs;
749 : }
750 : else
751 : {
752 0 : initial_res = true_beta;
753 : }
754 0 : eps = std::max(rel_tol * initial_res, abs_tol);
755 : }
756 0 : else if (beta > 0.0 && std::abs(beta - true_beta) > 0.1 * true_beta &&
757 0 : print_opts.warnings)
758 : {
759 0 : Mpi::Print(
760 0 : comm,
761 : "{}FGMRES residual at restart ({:.6e}) is far from the residual norm estimate "
762 : "from the recursion formula ({:.6e}) (initial residual = {:.6e})\n",
763 0 : std::string(tab_width, ' '), true_beta, beta, initial_res);
764 : }
765 0 : beta = true_beta;
766 0 : if (beta < eps)
767 : {
768 0 : converged = true;
769 0 : break;
770 : }
771 :
772 0 : V[0] = 0.0;
773 0 : V[0].Add(1.0 / beta, Z[0]);
774 : std::fill(s.begin(), s.end(), 0.0);
775 0 : s[0] = beta;
776 :
777 : int j = 0;
778 0 : for (;; j++, it++)
779 : {
780 0 : if (print_opts.iterations)
781 : {
782 0 : Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
783 0 : std::string(tab_width, ' '), it, int_width, restart, beta);
784 : }
785 0 : VecType &w = V[j + 1];
786 0 : if (w.Size() == 0)
787 : {
788 0 : Update(j);
789 : }
790 0 : ApplyBA(PreconditionerSide::RIGHT, A, B, V[j], w, Z[j], this->use_timer);
791 :
792 0 : ScalarType *Hj = H.data() + j * (max_dim + 1);
793 0 : OrthogonalizeIteration(gs_orthog, comm, V, w, Hj, j);
794 0 : Hj[j + 1] = linalg::Norml2(comm, w);
795 0 : w *= 1.0 / Hj[j + 1];
796 :
797 0 : for (int k = 0; k < j; k++)
798 : {
799 0 : ApplyPlaneRotation(Hj[k], Hj[k + 1], cs[k], sn[k]);
800 : }
801 0 : GeneratePlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
802 0 : ApplyPlaneRotation(Hj[j], Hj[j + 1], cs[j], sn[j]);
803 0 : ApplyPlaneRotation(s[j], s[j + 1], cs[j], sn[j]);
804 :
805 0 : beta = std::abs(s[j + 1]);
806 : CheckDot(beta, "FGMRES residual norm is not valid: beta = ");
807 0 : converged = (beta < eps);
808 0 : if (converged || j + 1 == max_dim || it + 1 == max_it)
809 : {
810 0 : it++;
811 : break;
812 : }
813 : }
814 :
815 : // Reconstruct the solution (for restart or due to convergence or maximum iterations).
816 0 : for (int i = j; i >= 0; i--)
817 : {
818 0 : ScalarType *Hi = H.data() + i * (max_dim + 1);
819 0 : s[i] /= Hi[i];
820 0 : for (int k = i - 1; k >= 0; k--)
821 : {
822 0 : s[k] -= Hi[k] * s[i];
823 : }
824 : }
825 0 : for (int k = 0; k <= j; k++)
826 : {
827 0 : x.Add(s[k], Z[k]);
828 : }
829 0 : if (converged)
830 : {
831 : break;
832 : }
833 : }
834 0 : if (print_opts.iterations)
835 : {
836 0 : Mpi::Print(comm, "{}{:{}d} (restart {:d}) KSP residual norm {:.6e}\n",
837 0 : std::string(tab_width, ' '), it, int_width, restart, beta);
838 : }
839 0 : if (print_opts.summary || (print_opts.warnings && eps > 0.0 && !converged))
840 : {
841 0 : Mpi::Print(comm, "{}FGMRES solver {} in {:d} iteration{}", std::string(tab_width, ' '),
842 0 : converged ? "converged" : "did NOT converge", it, (it == 1) ? "" : "s");
843 0 : if (it > 0)
844 : {
845 0 : Mpi::Print(comm, " (avg. reduction factor: {:.3e})\n",
846 0 : std::pow(beta / initial_res, 1.0 / it));
847 : }
848 : else
849 : {
850 0 : Mpi::Print(comm, "\n");
851 : }
852 : }
853 0 : final_res = beta;
854 0 : final_it = it;
855 0 : }
856 :
857 : template class IterativeSolver<Operator>;
858 : template class IterativeSolver<ComplexOperator>;
859 : template class CgSolver<Operator>;
860 : template class CgSolver<ComplexOperator>;
861 : template class GmresSolver<Operator>;
862 : template class GmresSolver<ComplexOperator>;
863 : template class FgmresSolver<Operator>;
864 : template class FgmresSolver<ComplexOperator>;
865 :
866 : } // namespace palace
|