Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_SOLVER_HPP
5 : #define PALACE_LINALG_SOLVER_HPP
6 :
#include <memory>
#include <type_traits>
#include <mfem.hpp>
#include "linalg/operator.hpp"
#include "linalg/vector.hpp"
11 :
12 : namespace palace
13 : {
14 :
15 : //
16 : // The base Solver<OperType> class is a templated version of mfem::Solver for operation with
17 : // real- or complex-valued operators.
18 : //
19 :
20 : // Abstract base class for real-valued or complex-valued solvers.
21 : template <typename OperType>
22 : class Solver : public OperType
23 : {
24 : static_assert(std::is_same<OperType, Operator>::value ||
25 : std::is_same<OperType, ComplexOperator>::value,
26 : "Solver can only be defined for OperType = Operator or ComplexOperator!");
27 :
28 : protected:
29 : using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
30 : ComplexVector, Vector>::type;
31 :
32 : // Whether or not to use the second argument of Mult() as an initial guess.
33 : bool initial_guess;
34 :
35 : public:
36 0 : Solver(bool initial_guess = false) : OperType(), initial_guess(initial_guess) {}
37 : virtual ~Solver() = default;
38 :
39 : // Configure whether or not to use an initial guess when applying the solver.
40 0 : virtual void SetInitialGuess(bool guess) { initial_guess = guess; }
41 :
42 : // Set the operator associated with the solver, or update it if called repeatedly.
43 : virtual void SetOperator(const OperType &op) = 0;
44 :
45 : // Apply the solver for the transpose problem.
46 0 : void MultTranspose(const VecType &x, VecType &y) const override
47 : {
48 0 : MFEM_ABORT("MultTranspose() is not implemented for base class Solver<OperType>!");
49 : }
50 :
51 : // Apply the solver with a preallocated temporary storage vector.
52 0 : virtual void Mult2(const VecType &x, VecType &y, VecType &r) const
53 : {
54 0 : MFEM_ABORT("Mult2() with temporary storage vector is not implemented for base class "
55 : "Solver<OperType>!");
56 : }
57 :
58 : // Apply the solver for the transpose problem with a preallocated temporary storage
59 : // vector.
60 0 : virtual void MultTranspose2(const VecType &x, VecType &y, VecType &r) const
61 : {
62 0 : MFEM_ABORT("MultTranspose2() with temporary storage vector is not implemented for base "
63 : "class Solver<OperType>!");
64 : }
65 : };
66 :
67 : // This solver wraps a real-valued mfem::Solver for application to complex-valued problems
68 : // as a preconditioner inside of a Solver<OperType> or for assembling the matrix-free
69 : // preconditioner operator as an mfem::HypreParMatrix.
70 : template <typename OperType>
71 : class MfemWrapperSolver : public Solver<OperType>
72 : {
73 : using VecType = typename Solver<OperType>::VecType;
74 :
75 : private:
76 : // The actual mfem::Solver.
77 : std::unique_ptr<mfem::Solver> pc;
78 :
79 : // System matrix A in parallel assembled form.
80 : std::unique_ptr<mfem::HypreParMatrix> A;
81 :
82 : // Whether or not to save the parallel assembled matrix after calling
83 : // mfem::Solver::SetOperator (some solvers copy their input).
84 : bool save_assembled;
85 :
86 : // Whether to use the exact complex-valued system matrix or the real-valued
87 : // approximation A = Ar + Ai.
88 : bool complex_matrix = true;
89 :
90 : public:
91 0 : MfemWrapperSolver(std::unique_ptr<mfem::Solver> &&pc, bool save_assembled = true,
92 : bool complex_matrix = true)
93 0 : : Solver<OperType>(pc->iterative_mode), pc(std::move(pc)),
94 0 : save_assembled(save_assembled), complex_matrix(complex_matrix)
95 : {
96 : }
97 :
98 : // Access the underlying solver.
99 : const mfem::Solver &GetSolver() { return *pc; }
100 :
101 : // Configure whether or not to save the assembled operator.
102 0 : void SetSaveAssembled(bool save) { save_assembled = save; }
103 :
104 0 : void SetInitialGuess(bool guess) override
105 : {
106 : Solver<OperType>::SetInitialGuess(guess);
107 0 : pc->iterative_mode = guess;
108 0 : }
109 :
110 : void SetOperator(const OperType &op) override;
111 :
112 : void Mult(const VecType &x, VecType &y) const override;
113 : };
114 :
115 : } // namespace palace
116 :
117 : #endif // PALACE_LINALG_SOLVER_HPP
|