Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_OPERATOR_HPP
5 : #define PALACE_LINALG_OPERATOR_HPP
6 :
7 : #include <complex>
8 : #include <memory>
9 : #include <type_traits>
10 : #include <utility>
11 : #include <vector>
12 : #include "linalg/vector.hpp"
13 :
14 : namespace palace
15 : {
16 :
17 : //
18 : // Functionality extending mfem::Operator from MFEM.
19 : //
20 :
21 : using Operator = mfem::Operator;
22 :
23 : // Abstract base class for complex-valued operators.
24 : class ComplexOperator
25 : {
26 : protected:
27 : // The size of the complex-valued operator.
28 : int height, width;
29 :
30 : public:
31 9 : ComplexOperator(int s = 0) : height(s), width(s) {}
32 9 : ComplexOperator(int h, int w) : height(h), width(w) {}
33 : virtual ~ComplexOperator() = default;
34 :
35 : // Get the height (size of output) of the operator.
36 6 : int Height() const { return height; }
37 :
38 : // Get the width (size of input) of the operator.
39 3 : int Width() const { return width; }
40 :
41 : // Test whether or not the operator is purely real or imaginary.
42 0 : virtual bool IsReal() const { return !Imag(); }
43 0 : virtual bool IsImag() const { return !Real(); }
44 :
45 : // Get access to the real and imaginary operator parts separately (may be empty if
46 : // operator is purely real or imaginary).
47 : virtual const Operator *Real() const;
48 : virtual const Operator *Imag() const;
49 :
50 : // Diagonal assembly.
51 : virtual void AssembleDiagonal(ComplexVector &diag) const;
52 :
53 : // Operator application.
54 : virtual void Mult(const ComplexVector &x, ComplexVector &y) const = 0;
55 :
56 : virtual void MultTranspose(const ComplexVector &x, ComplexVector &y) const;
57 :
58 : virtual void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const;
59 :
60 : virtual void AddMult(const ComplexVector &x, ComplexVector &y,
61 : const std::complex<double> a = 1.0) const;
62 :
63 : virtual void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
64 : const std::complex<double> a = 1.0) const;
65 :
66 : virtual void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
67 : const std::complex<double> a = 1.0) const;
68 : };
69 :
70 : // A complex-valued operator represented using a block 2 x 2 equivalent-real formulation:
71 : // [ yr ] = [ Ar -Ai ] [ xr ]
72 : // [ yi ] [ Ai Ar ] [ xi ] .
73 : class ComplexWrapperOperator : public ComplexOperator
74 : {
75 : private:
76 : // Storage and access for real and imaginary parts of the operator.
77 : std::unique_ptr<Operator> data_Ar, data_Ai;
78 : const Operator *Ar, *Ai;
79 :
80 : // Temporary storage for operator application.
81 : mutable ComplexVector tx, ty;
82 :
83 : ComplexWrapperOperator(std::unique_ptr<Operator> &&dAr, std::unique_ptr<Operator> &&dAi,
84 : const Operator *pAr, const Operator *pAi);
85 :
86 : public:
87 : // Construct a complex operator which inherits ownership of the input real and imaginary
88 : // parts.
89 : ComplexWrapperOperator(std::unique_ptr<Operator> &&Ar, std::unique_ptr<Operator> &&Ai);
90 :
91 : // Non-owning constructor.
92 : ComplexWrapperOperator(const Operator *Ar, const Operator *Ai);
93 :
94 42 : const Operator *Real() const override { return Ar; }
95 42 : const Operator *Imag() const override { return Ai; }
96 :
97 : void AssembleDiagonal(ComplexVector &diag) const override;
98 :
99 : void Mult(const ComplexVector &x, ComplexVector &y) const override;
100 :
101 : void MultTranspose(const ComplexVector &x, ComplexVector &y) const override;
102 :
103 : void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
104 :
105 : void AddMult(const ComplexVector &x, ComplexVector &y,
106 : const std::complex<double> a = 1.0) const override;
107 :
108 : void AddMultTranspose(const ComplexVector &x, ComplexVector &y,
109 : const std::complex<double> a = 1.0) const override;
110 :
111 : void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
112 : const std::complex<double> a = 1.0) const override;
113 : };
114 :
115 : // Wrap a sequence of operators of the same dimensions and optional coefficients.
116 : class SumOperator : public Operator
117 : {
118 : private:
119 : std::vector<std::pair<const Operator *, double>> ops;
120 : mutable Vector z;
121 :
122 : public:
123 : SumOperator(int s) : Operator(s) { z.UseDevice(true); }
124 9 : SumOperator(int h, int w) : Operator(h, w) { z.UseDevice(true); }
125 : SumOperator(const Operator &op, double a = 1.0);
126 :
127 : void AddOperator(const Operator &op, double a = 1.0);
128 :
129 : void Mult(const Vector &x, Vector &y) const override;
130 :
131 : void MultTranspose(const Vector &x, Vector &y) const override;
132 :
133 : void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
134 :
135 : void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
136 : };
137 :
138 : // Wraps two operators such that: (AB)ᵀ = BᵀAᵀ and, for complex symmetric operators, the
139 : // Hermitian transpose operation is (AB)ᴴ = BᴴAᴴ.
140 : template <typename ProductOperator, typename OperType>
141 : class ProductOperatorHelper : public OperType
142 : {
143 : };
144 :
145 : template <typename ProductOperator>
146 : class ProductOperatorHelper<ProductOperator, Operator> : public Operator
147 : {
148 : public:
149 : ProductOperatorHelper(int h, int w) : Operator(h, w) {}
150 : };
151 :
152 : template <typename ProductOperator>
153 : class ProductOperatorHelper<ProductOperator, ComplexOperator> : public ComplexOperator
154 : {
155 : public:
156 : ProductOperatorHelper(int h, int w) : ComplexOperator(h, w) {}
157 :
158 0 : void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override
159 : {
160 0 : const ComplexOperator &A = static_cast<const ProductOperator *>(this)->A;
161 0 : const ComplexOperator &B = static_cast<const ProductOperator *>(this)->B;
162 0 : ComplexVector &z = static_cast<const ProductOperator *>(this)->z;
163 0 : A.MultHermitianTranspose(x, z);
164 0 : B.MultHermitianTranspose(z, y);
165 0 : }
166 :
167 0 : void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
168 : const std::complex<double> a = 1.0) const override
169 : {
170 0 : const ComplexOperator &A = static_cast<const ProductOperator *>(this)->A;
171 0 : const ComplexOperator &B = static_cast<const ProductOperator *>(this)->B;
172 0 : ComplexVector &z = static_cast<const ProductOperator *>(this)->z;
173 0 : A.MultHermitianTranspose(x, z);
174 0 : B.AddMultHermitianTranspose(z, y, a);
175 0 : }
176 : };
177 :
178 : template <typename OperType>
179 0 : class BaseProductOperator
180 : : public ProductOperatorHelper<BaseProductOperator<OperType>, OperType>
181 : {
182 : friend class ProductOperatorHelper<BaseProductOperator<OperType>, OperType>;
183 :
184 : using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
185 : ComplexVector, Vector>::type;
186 : using ScalarType =
187 : typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
188 : std::complex<double>, double>::type;
189 :
190 : private:
191 : const OperType &A, &B;
192 : mutable VecType z;
193 :
194 : public:
195 0 : BaseProductOperator(const OperType &A, const OperType &B)
196 : : ProductOperatorHelper<BaseProductOperator<OperType>, OperType>(A.Height(), B.Width()),
197 0 : A(A), B(B), z(B.Height())
198 : {
199 0 : z.UseDevice(true);
200 0 : }
201 :
202 0 : void Mult(const VecType &x, VecType &y) const override
203 : {
204 0 : B.Mult(x, z);
205 0 : A.Mult(z, y);
206 0 : }
207 :
208 0 : void MultTranspose(const VecType &x, VecType &y) const override
209 : {
210 0 : A.MultTranspose(x, z);
211 0 : B.MultTranspose(z, y);
212 0 : }
213 :
214 0 : void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
215 : {
216 0 : B.Mult(x, z);
217 0 : A.AddMult(z, y, a);
218 0 : }
219 :
220 0 : void AddMultTranspose(const VecType &x, VecType &y,
221 : const ScalarType a = 1.0) const override
222 : {
223 0 : A.MultTranspose(x, z);
224 0 : B.AddMultTranspose(z, y, a);
225 0 : }
226 : };
227 :
228 : using ProductOperator = BaseProductOperator<Operator>;
229 : using ComplexProductOperator = BaseProductOperator<ComplexOperator>;
230 :
231 : // Applies the simple, symmetric but not necessarily Hermitian, operator: diag(d).
232 : template <typename DiagonalOperator, typename OperType>
233 : class DiagonalOperatorHelper : public OperType
234 : {
235 : };
236 :
237 : template <typename DiagonalOperator>
238 : class DiagonalOperatorHelper<DiagonalOperator, Operator> : public Operator
239 : {
240 : public:
241 : DiagonalOperatorHelper(int s) : Operator(s) {}
242 : };
243 :
244 : template <typename DiagonalOperator>
245 : class DiagonalOperatorHelper<DiagonalOperator, ComplexOperator> : public ComplexOperator
246 : {
247 : public:
248 : DiagonalOperatorHelper(int s) : ComplexOperator(s) {}
249 :
250 : void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;
251 :
252 : void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
253 : const std::complex<double> a = 1.0) const override;
254 : };
255 :
256 : template <typename OperType>
257 0 : class BaseDiagonalOperator
258 : : public DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>
259 : {
260 : friend class DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>;
261 :
262 : using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
263 : ComplexVector, Vector>::type;
264 : using ScalarType =
265 : typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
266 : std::complex<double>, double>::type;
267 :
268 : private:
269 : const VecType &d;
270 :
271 : public:
272 0 : BaseDiagonalOperator(const VecType &d)
273 0 : : DiagonalOperatorHelper<BaseDiagonalOperator<OperType>, OperType>(d.Size()), d(d)
274 : {
275 : }
276 :
277 : void Mult(const VecType &x, VecType &y) const override;
278 :
279 0 : void MultTranspose(const VecType &x, VecType &y) const override { Mult(x, y); }
280 :
281 : void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override;
282 :
283 0 : void AddMultTranspose(const VecType &x, VecType &y,
284 : const ScalarType a = 1.0) const override
285 : {
286 0 : AddMult(x, y, a);
287 0 : }
288 : };
289 :
290 : using DiagonalOperator = BaseDiagonalOperator<Operator>;
291 : using ComplexDiagonalOperator = BaseDiagonalOperator<ComplexOperator>;
292 :
293 : // A container for a sequence of operators corresponding to a multigrid hierarchy.
294 : // Optionally includes operators for the auxiliary space at each level as well. The
295 : // Operators are stored from coarsest to finest level. The height and width of this operator
296 : // are never set.
297 : template <typename OperType>
298 : class BaseMultigridOperator : public OperType
299 : {
300 : using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
301 : ComplexVector, Vector>::type;
302 : using ScalarType =
303 : typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
304 : std::complex<double>, double>::type;
305 :
306 : private:
307 : std::vector<std::unique_ptr<OperType>> ops, aux_ops;
308 :
309 : public:
310 0 : BaseMultigridOperator(std::size_t l) : OperType(0)
311 : {
312 0 : ops.reserve(l);
313 0 : aux_ops.reserve(l);
314 0 : }
315 :
316 : void AddOperator(std::unique_ptr<OperType> &&op)
317 : {
318 0 : ops.push_back(std::move(op));
319 0 : this->height = ops.back()->Height();
320 0 : this->width = ops.back()->Width();
321 : }
322 :
323 : void AddAuxiliaryOperator(std::unique_ptr<OperType> &&aux_op)
324 : {
325 0 : aux_ops.push_back(std::move(aux_op));
326 0 : }
327 :
328 : bool HasAuxiliaryOperators() const { return !aux_ops.empty(); }
329 : auto GetNumLevels() const { return ops.size(); }
330 : auto GetNumAuxiliaryLevels() const { return aux_ops.size(); }
331 :
332 : const OperType &GetFinestOperator() const { return *ops.back(); }
333 : const OperType &GetFinestAuxiliaryOperator() const { return *aux_ops.back(); }
334 :
335 : const OperType &GetOperatorAtLevel(std::size_t l) const
336 : {
337 : MFEM_ASSERT(l < GetNumLevels(), "Out of bounds multigrid level operator requested!");
338 : return *ops[l];
339 : }
340 : const OperType &GetAuxiliaryOperatorAtLevel(std::size_t l) const
341 : {
342 : MFEM_ASSERT(l < GetNumAuxiliaryLevels(),
343 : "Out of bounds multigrid level auxiliary operator requested!");
344 : return *aux_ops[l];
345 : }
346 :
347 0 : void Mult(const VecType &x, VecType &y) const override { GetFinestOperator().Mult(x, y); }
348 :
349 0 : void MultTranspose(const VecType &x, VecType &y) const override
350 : {
351 0 : GetFinestOperator().MultTranspose(x, y);
352 0 : }
353 :
354 0 : void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
355 : {
356 0 : GetFinestOperator().AddMult(x, y, a);
357 0 : }
358 :
359 0 : void AddMultTranspose(const VecType &x, VecType &y,
360 : const ScalarType a = 1.0) const override
361 : {
362 0 : GetFinestOperator().AddMultTranspose(x, y, a);
363 0 : }
364 : };
365 :
366 : using MultigridOperator = BaseMultigridOperator<Operator>;
367 : using ComplexMultigridOperator = BaseMultigridOperator<ComplexOperator>;
368 :
369 : namespace linalg
370 : {
371 :
372 : // Calculate the vector norm with respect to an SPD matrix B.
373 : template <typename VecType>
374 : double Norml2(MPI_Comm comm, const VecType &x, const Operator &B, VecType &Bx);
375 :
376 : // Normalize the vector with respect to an SPD matrix B.
377 : template <typename VecType>
378 : inline double Normalize(MPI_Comm comm, VecType &x, const Operator &B, VecType &Bx)
379 : {
380 : double norm = Norml2(comm, x, B, Bx);
381 : MFEM_ASSERT(norm > 0.0, "Zero vector norm in normalization!");
382 : x *= 1.0 / norm;
383 : return norm;
384 : }
385 :
386 : // Estimate operator 2-norm (spectral norm) using power iteration. Assumes the operator is
387 : // not symmetric or Hermitian unless specified.
388 : double SpectralNorm(MPI_Comm comm, const Operator &A, bool sym = false, double tol = 1.0e-4,
389 : int max_it = 1000);
390 : double SpectralNorm(MPI_Comm comm, const ComplexOperator &A, bool herm = false,
391 : double tol = 1.0e-4, int max_it = 1000);
392 :
393 : } // namespace linalg
394 :
395 : } // namespace palace
396 :
397 : #endif // PALACE_LINALG_OPERATOR_HPP
|