Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_RAP_HPP
5 : #define PALACE_LINALG_RAP_HPP
6 :
7 : #include <array>
8 : #include <memory>
9 : #include <mfem.hpp>
10 : #include "fem/fespace.hpp"
11 : #include "linalg/operator.hpp"
12 : #include "linalg/vector.hpp"
13 :
14 : namespace palace
15 : {
16 :
//
// A parallel operator represented by RAP, constructed through the actions of R, A, and P,
// usually with R = Pᵀ and with possibly eliminated essential boundary conditions. Here R
// and P are the parallel restriction and prolongation matrices.
//
22 :
23 : // Real-valued RAP operator.
24 : class ParOperator : public Operator
25 : {
26 : private:
27 : // Storage and access for the local operator.
28 : std::unique_ptr<Operator> data_A;
29 : const Operator *A;
30 :
31 : // Finite element spaces for parallel prolongation and restriction.
32 : const FiniteElementSpace &trial_fespace, &test_fespace;
33 : const bool use_R;
34 :
35 : // Lists of constrained essential boundary true dofs for elimination.
36 : mfem::Array<int> dbc_tdof_list;
37 :
38 : // Diagonal policy for constrained true dofs.
39 : DiagonalPolicy diag_policy = DiagonalPolicy::DIAG_ZERO;
40 :
41 : // Assembled operator as a parallel Hypre matrix. If assembled, the local operator is not
42 : // deleted.
43 : mutable std::unique_ptr<mfem::HypreParMatrix> RAP;
44 :
45 : // Helper methods for operator application.
46 : void RestrictionMatrixMult(const Vector &ly, Vector &ty) const;
47 : void RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const;
48 : Vector &GetTestLVector() const;
49 :
50 : ParOperator(std::unique_ptr<Operator> &&dA, const Operator *pA,
51 : const FiniteElementSpace &trial_fespace,
52 : const FiniteElementSpace &test_fespace, bool test_restrict);
53 :
54 : public:
55 : // Construct the parallel operator, inheriting ownership of the local operator.
56 : ParOperator(std::unique_ptr<Operator> &&A, const FiniteElementSpace &trial_fespace,
57 : const FiniteElementSpace &test_fespace, bool test_restrict);
58 : ParOperator(std::unique_ptr<Operator> &&A, const FiniteElementSpace &fespace)
59 9 : : ParOperator(std::move(A), fespace, fespace, false)
60 : {
61 9 : }
62 :
63 : // Non-owning constructors.
64 : ParOperator(const Operator &A, const FiniteElementSpace &trial_fespace,
65 : const FiniteElementSpace &test_fespace, bool test_restrict);
66 : ParOperator(const Operator &A, const FiniteElementSpace &fespace)
67 : : ParOperator(A, fespace, fespace, false)
68 : {
69 : }
70 :
71 : // Get access to the underlying local (L-vector) operator.
72 12 : const Operator &LocalOperator() const { return *A; }
73 :
74 : // Get the associated MPI communicator.
75 : MPI_Comm GetComm() const { return trial_fespace.GetComm(); }
76 :
77 : // Accessor for trial finite element space.
78 9 : const FiniteElementSpace &TrialFiniteElementSpace() const { return trial_fespace; }
79 :
80 : // Accessor for test finite element space.
81 : const FiniteElementSpace &TestFiniteElementSpace() const { return test_fespace; }
82 :
83 : // Set essential boundary condition true dofs for square operators.
84 : void SetEssentialTrueDofs(const mfem::Array<int> &tdof_list, DiagonalPolicy policy);
85 :
86 : // Get the essential boundary condition true dofs associated with the operator. May be
87 : // nullptr.
88 : const mfem::Array<int> *GetEssentialTrueDofs() const
89 : {
90 6 : return dbc_tdof_list.Size() ? &dbc_tdof_list : nullptr;
91 : }
92 :
93 : // Get the diagonal policy that was most recently used. If there are no essential dofs,
94 : // and thus no valid policy, will error.
95 : DiagonalPolicy GetDiagonalPolicy() const;
96 :
97 : // Eliminate essential true dofs from the RHS vector b, using the essential boundary
98 : // condition values in x.
99 : void EliminateRHS(const Vector &x, Vector &b) const;
100 :
101 : // Assemble the operator as a parallel sparse matrix. The memory associated with the
102 : // local operator is free'd.
103 : mfem::HypreParMatrix &ParallelAssemble(bool skip_zeros = false) const;
104 :
105 : // Steal the assembled parallel sparse matrix.
106 : std::unique_ptr<mfem::HypreParMatrix> StealParallelAssemble(bool skip_zeros = false) const
107 : {
108 0 : ParallelAssemble(skip_zeros);
109 : return std::move(RAP);
110 : }
111 :
112 : void AssembleDiagonal(Vector &diag) const override;
113 :
114 : void Mult(const Vector &x, Vector &y) const override;
115 :
116 : void MultTranspose(const Vector &x, Vector &y) const override;
117 :
118 : void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
119 :
120 : void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
121 : };
122 :
123 : // Complex-valued RAP operator.
124 : class ComplexParOperator : public ComplexOperator
125 : {
126 : private:
127 : // Storage and access for the local operator.
128 : std::unique_ptr<ComplexWrapperOperator> data_A;
129 : const ComplexWrapperOperator *A;
130 :
131 : // Finite element spaces for parallel prolongation and restriction.
132 : const FiniteElementSpace &trial_fespace, &test_fespace;
133 : const bool use_R;
134 :
135 : // Lists of constrained essential boundary true dofs for elimination.
136 : mfem::Array<int> dbc_tdof_list;
137 :
138 : // Diagonal policy for constrained true dofs.
139 : Operator::DiagonalPolicy diag_policy = Operator::DiagonalPolicy::DIAG_ZERO;
140 :
141 : // Real and imaginary parts of the operator as non-owning ParOperator objects.
142 : std::unique_ptr<ParOperator> RAPr, RAPi;
143 :
144 : // Helper methods for operator application.
145 : void RestrictionMatrixMult(const ComplexVector &ly, ComplexVector &ty) const;
146 : void RestrictionMatrixMultTranspose(const ComplexVector &ty, ComplexVector &ly) const;
147 : ComplexVector &GetTestLVector() const;
148 :
149 : ComplexParOperator(std::unique_ptr<Operator> &&dAr, std::unique_ptr<Operator> &&dAi,
150 : const Operator *pAr, const Operator *pAi,
151 : const FiniteElementSpace &trial_fespace,
152 : const FiniteElementSpace &test_fespace, bool test_restrict);
153 :
154 : public:
155 : // Construct the complex-valued parallel operator from the separate real and imaginary
156 : // parts, inheriting ownership of the local operator.
157 : ComplexParOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai,
158 : const FiniteElementSpace &trial_fespace,
159 : const FiniteElementSpace &test_fespace, bool test_restrict);
160 : ComplexParOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai,
161 : const FiniteElementSpace &fespace)
162 9 : : ComplexParOperator(std::move(Ar), std::move(Ai), fespace, fespace, false)
163 : {
164 9 : }
165 :
166 : // Non-owning constructors.
167 : ComplexParOperator(const Operator *Ar, const Operator *Ai,
168 : const FiniteElementSpace &trial_fespace,
169 : const FiniteElementSpace &test_fespace, bool test_restrict);
170 : ComplexParOperator(const Operator *Ar, const Operator *Ai,
171 : const FiniteElementSpace &fespace)
172 : : ComplexParOperator(Ar, Ai, fespace, fespace, false)
173 : {
174 : }
175 :
176 0 : const Operator *Real() const override { return RAPr.get(); }
177 0 : const Operator *Imag() const override { return RAPi.get(); }
178 :
179 : // Get access to the underlying local (L-vector) operator.
180 51 : const ComplexOperator &LocalOperator() const { return *A; }
181 :
182 : // Get the associated MPI communicator.
183 : MPI_Comm GetComm() const { return trial_fespace.GetComm(); }
184 :
185 : // Accessor for trial finite element space.
186 9 : const FiniteElementSpace &TrialFiniteElementSpace() const { return trial_fespace; }
187 :
188 : // Accessor for test finite element space.
189 : const FiniteElementSpace &TestFiniteElementSpace() const { return test_fespace; }
190 :
191 : // Set essential boundary condition true dofs for square operators.
192 : void SetEssentialTrueDofs(const mfem::Array<int> &tdof_list,
193 : Operator::DiagonalPolicy policy);
194 :
195 : // Get the essential boundary condition true dofs associated with the operator. May be
196 : // nullptr.
197 : const mfem::Array<int> *GetEssentialTrueDofs() const
198 : {
199 6 : return dbc_tdof_list.Size() ? &dbc_tdof_list : nullptr;
200 : }
201 :
202 : // Get the diagonal policy that was most recently used. If there are no essential dofs,
203 : // and thus no valid policy, will error.
204 : Operator::DiagonalPolicy GetDiagonalPolicy() const;
205 :
206 : void AssembleDiagonal(ComplexVector &diag) const override;
207 :
208 : void Mult(const ComplexVector &x, ComplexVector &y) const override;
209 :
210 : void MultTranspose(const ComplexVector &x, ComplexVector &y) const override;
211 :
212 : void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
213 :
214 : void AddMult(const ComplexVector &x, ComplexVector &y,
215 : const std::complex<double> a = 1.0) const override;
216 :
217 : void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
218 : const std::complex<double> a = 1.0) const override;
219 :
220 : void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
221 : const std::complex<double> a = 1.0) const override;
222 : };
223 :
224 : // Combine a collection of ParOperator into a weighted summation. If set_essential is true,
225 : // extract the essential dofs from the operator array, and apply to the summed operator.
226 : // Requires explicit instantiation.
227 : template <std::size_t N>
228 : std::unique_ptr<ParOperator>
229 : BuildParSumOperator(const std::array<double, N> &coeff,
230 : const std::array<const ParOperator *, N> &ops,
231 : bool set_essential = true);
232 :
233 : // Combine a collection of ComplexParOperator into a weighted summation. If set_essential is
234 : // true, extract the essential dofs from the operator array, and apply to the summed
235 : // operator. Requires explicit instantiation.
236 : template <std::size_t N>
237 : std::unique_ptr<ComplexParOperator>
238 : BuildParSumOperator(const std::array<std::complex<double>, N> &coeff,
239 : const std::array<const ComplexParOperator *, N> &ops,
240 : bool set_essential = true);
241 :
242 : // Dispatcher to convert initializer list or C arrays into std::array whilst deducing sizes
243 : // and types.
244 : template <std::size_t N, typename ScalarType, typename OperType>
245 : std::unique_ptr<std::conditional_t<std::is_base_of_v<ComplexOperator, OperType>,
246 : ComplexParOperator, ParOperator>>
247 : BuildParSumOperator(ScalarType (&&coeff_in)[N], const OperType *(&&ops_in)[N],
248 : bool set_essential = true);
249 :
250 : } // namespace palace
251 :
252 : #endif // PALACE_LINALG_RAP_HPP
|