Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "rap.hpp"
5 :
6 : #include "fem/bilinearform.hpp"
7 : #include "linalg/hypre.hpp"
8 :
9 : namespace palace
10 : {
11 :
12 33 : ParOperator::ParOperator(std::unique_ptr<Operator> &&dA, const Operator *pA,
13 : const FiniteElementSpace &trial_fespace,
14 33 : const FiniteElementSpace &test_fespace, bool test_restrict)
15 : : Operator(test_fespace.GetTrueVSize(), trial_fespace.GetTrueVSize()),
16 33 : data_A(std::move(dA)), A((data_A != nullptr) ? data_A.get() : pA),
17 33 : trial_fespace(trial_fespace), test_fespace(test_fespace), use_R(test_restrict),
18 48 : diag_policy(DiagonalPolicy::DIAG_ONE), RAP(nullptr)
19 : {
20 33 : MFEM_VERIFY(A, "Cannot construct ParOperator from an empty matrix!");
21 33 : }
22 :
23 15 : ParOperator::ParOperator(std::unique_ptr<Operator> &&A,
24 : const FiniteElementSpace &trial_fespace,
25 15 : const FiniteElementSpace &test_fespace, bool test_restrict)
26 15 : : ParOperator(std::move(A), nullptr, trial_fespace, test_fespace, test_restrict)
27 : {
28 15 : }
29 :
30 18 : ParOperator::ParOperator(const Operator &A, const FiniteElementSpace &trial_fespace,
31 18 : const FiniteElementSpace &test_fespace, bool test_restrict)
32 18 : : ParOperator(nullptr, &A, trial_fespace, test_fespace, test_restrict)
33 : {
34 18 : }
35 :
36 0 : void ParOperator::SetEssentialTrueDofs(const mfem::Array<int> &tdof_list,
37 : DiagonalPolicy policy)
38 : {
39 0 : MFEM_VERIFY(policy == DiagonalPolicy::DIAG_ONE || policy == DiagonalPolicy::DIAG_ZERO,
40 : "Essential boundary condition true dof elimination for ParOperator supports "
41 : "only DiagonalPolicy::DIAG_ONE or DiagonalPolicy::DIAG_ZERO!");
42 0 : MFEM_VERIFY(height == width, "Set essential true dofs for both test and trial spaces "
43 : "for rectangular ParOperator!");
44 : tdof_list.Read();
45 0 : dbc_tdof_list.MakeRef(tdof_list);
46 0 : diag_policy = policy;
47 0 : }
48 :
49 0 : Operator::DiagonalPolicy ParOperator::GetDiagonalPolicy() const
50 : {
51 0 : MFEM_VERIFY(dbc_tdof_list.Size() > 0,
52 : "There is no DiagonalPolicy if no essential dofs have been set!");
53 0 : return diag_policy;
54 : }
55 :
56 0 : void ParOperator::EliminateRHS(const Vector &x, Vector &b) const
57 : {
58 0 : MFEM_VERIFY(A, "No local matrix available for ParOperator::EliminateRHS!");
59 0 : auto &lx = trial_fespace.GetLVector<Vector>();
60 0 : auto &ly = GetTestLVector();
61 : {
62 0 : auto &tx = trial_fespace.GetTVector<Vector>();
63 0 : tx = 0.0;
64 0 : linalg::SetSubVector(tx, dbc_tdof_list, x);
65 0 : trial_fespace.GetProlongationMatrix()->Mult(tx, lx);
66 : }
67 :
68 : // Apply the unconstrained operator.
69 0 : A->Mult(lx, ly);
70 :
71 0 : auto &ty = test_fespace.GetTVector<Vector>();
72 0 : RestrictionMatrixMult(ly, ty);
73 0 : b.Add(-1.0, ty);
74 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
75 : {
76 0 : linalg::SetSubVector(b, dbc_tdof_list, x);
77 : }
78 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
79 : {
80 0 : linalg::SetSubVector(b, dbc_tdof_list, 0.0);
81 : }
82 0 : }
83 :
84 0 : mfem::HypreParMatrix &ParOperator::ParallelAssemble(bool skip_zeros) const
85 : {
86 0 : if (RAP)
87 : {
88 : return *RAP;
89 : }
90 :
91 : // Build the square or rectangular assembled HypreParMatrix.
92 0 : const auto *sA = dynamic_cast<const hypre::HypreCSRMatrix *>(A);
93 : std::unique_ptr<hypre::HypreCSRMatrix> data_sA;
94 0 : if (!sA)
95 : {
96 0 : const auto *cA = dynamic_cast<const ceed::Operator *>(A);
97 0 : MFEM_VERIFY(cA,
98 : "ParOperator::ParallelAssemble requires A as an hypre::HypreCSRMatrix or "
99 : "ceed::Operator!");
100 0 : data_sA = BilinearForm::FullAssemble(*cA, skip_zeros, use_R);
101 : sA = data_sA.get();
102 : }
103 :
104 0 : hypre_ParCSRMatrix *hA = hypre_ParCSRMatrixCreate(
105 0 : trial_fespace.GetComm(), test_fespace.GlobalVSize(), trial_fespace.GlobalVSize(),
106 0 : test_fespace.Get().GetDofOffsets(), trial_fespace.Get().GetDofOffsets(), 0, sA->NNZ(),
107 : 0);
108 0 : hypre_CSRMatrix *hA_diag = hypre_ParCSRMatrixDiag(hA);
109 0 : hypre_ParCSRMatrixDiag(hA) = *const_cast<hypre::HypreCSRMatrix *>(sA);
110 0 : hypre_ParCSRMatrixInitialize(hA);
111 :
112 0 : const mfem::HypreParMatrix *P = trial_fespace.Get().Dof_TrueDof_Matrix();
113 0 : if (!use_R)
114 : {
115 0 : const mfem::HypreParMatrix *Rt = test_fespace.Get().Dof_TrueDof_Matrix();
116 0 : RAP = std::make_unique<mfem::HypreParMatrix>(hypre_ParCSRMatrixRAPKT(*Rt, hA, *P, 1),
117 0 : true);
118 : }
119 : else
120 : {
121 : mfem::HypreParMatrix *hR = new mfem::HypreParMatrix(
122 0 : test_fespace.GetComm(), test_fespace.GlobalTrueVSize(), test_fespace.GlobalVSize(),
123 0 : test_fespace.Get().GetTrueDofOffsets(), test_fespace.Get().GetDofOffsets(),
124 0 : const_cast<mfem::SparseMatrix *>(test_fespace.GetRestrictionMatrix()));
125 0 : hypre_ParCSRMatrix *AP = hypre_ParCSRMatMat(hA, *P);
126 0 : RAP = std::make_unique<mfem::HypreParMatrix>(hypre_ParCSRMatMat(*hR, AP), true);
127 0 : hypre_ParCSRMatrixDestroy(AP);
128 0 : delete hR;
129 : }
130 :
131 0 : hypre_ParCSRMatrixDiag(hA) = hA_diag;
132 0 : hypre_ParCSRMatrixDestroy(hA);
133 0 : hypre_ParCSRMatrixSetNumNonzeros(*RAP);
134 0 : if (&trial_fespace == &test_fespace)
135 : {
136 : // Make sure that the first entry in each row is the diagonal one, for a square matrix.
137 0 : hypre_CSRMatrixReorder(hypre_ParCSRMatrixDiag((hypre_ParCSRMatrix *)*RAP));
138 : }
139 :
140 : // Eliminate boundary conditions on the assembled (square) matrix.
141 0 : if (&trial_fespace == &test_fespace)
142 : {
143 0 : RAP->EliminateBC(dbc_tdof_list, diag_policy);
144 : }
145 : else
146 : {
147 0 : MFEM_VERIFY(dbc_tdof_list.Size() == 0,
148 : "Essential BC elimination is only available for square ParOperator!");
149 : }
150 :
151 : return *RAP;
152 : }
153 :
154 0 : void ParOperator::AssembleDiagonal(Vector &diag) const
155 : {
156 0 : diag.UseDevice(true);
157 0 : if (RAP)
158 : {
159 0 : RAP->AssembleDiagonal(diag);
160 0 : return;
161 : }
162 :
163 : // For an AMR mesh, a convergent diagonal is assembled with |P|ᵀ dₗ, where |P| has
164 : // entry-wise absolute values of the conforming prolongation operator.
165 0 : MFEM_VERIFY(&trial_fespace == &test_fespace,
166 : "Diagonal assembly is only available for square ParOperator!");
167 : auto &lx = trial_fespace.GetLVector<Vector>();
168 0 : A->AssembleDiagonal(lx);
169 :
170 : // Parallel assemble and eliminate essential true dofs.
171 0 : const Operator *P = test_fespace.GetProlongationMatrix();
172 0 : if (const auto *hP = dynamic_cast<const mfem::HypreParMatrix *>(P))
173 : {
174 0 : hP->AbsMultTranspose(1.0, lx, 0.0, diag);
175 : }
176 : else
177 : {
178 0 : P->MultTranspose(lx, diag);
179 : }
180 :
181 : // Eliminate essential true dofs.
182 0 : if (dbc_tdof_list.Size())
183 : {
184 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
185 : {
186 0 : linalg::SetSubVector(diag, dbc_tdof_list, 1.0);
187 : }
188 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
189 : {
190 0 : linalg::SetSubVector(diag, dbc_tdof_list, 0.0);
191 : }
192 : }
193 : }
194 :
195 3 : void ParOperator::Mult(const Vector &x, Vector &y) const
196 : {
197 : MFEM_ASSERT(x.Size() == width && y.Size() == height,
198 : "Incompatible dimensions for ParOperator::Mult!");
199 3 : if (RAP)
200 : {
201 0 : RAP->Mult(x, y);
202 0 : return;
203 : }
204 :
205 3 : auto &lx = trial_fespace.GetLVector<Vector>();
206 3 : auto &ly = GetTestLVector();
207 3 : if (dbc_tdof_list.Size())
208 : {
209 0 : auto &tx = trial_fespace.GetTVector<Vector>();
210 0 : tx = x;
211 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
212 0 : trial_fespace.GetProlongationMatrix()->Mult(tx, lx);
213 : }
214 : else
215 : {
216 3 : trial_fespace.GetProlongationMatrix()->Mult(x, lx);
217 : }
218 :
219 : // Apply the operator on the L-vector.
220 3 : A->Mult(lx, ly);
221 :
222 3 : RestrictionMatrixMult(ly, y);
223 3 : if (dbc_tdof_list.Size())
224 : {
225 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
226 : {
227 0 : linalg::SetSubVector(y, dbc_tdof_list, x);
228 : }
229 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
230 : {
231 0 : linalg::SetSubVector(y, dbc_tdof_list, 0.0);
232 : }
233 : }
234 : }
235 :
236 0 : void ParOperator::MultTranspose(const Vector &x, Vector &y) const
237 : {
238 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
239 : "Incompatible dimensions for ParOperator::MultTranspose!");
240 0 : if (RAP)
241 : {
242 0 : RAP->MultTranspose(x, y);
243 0 : return;
244 : }
245 :
246 0 : auto &lx = trial_fespace.GetLVector<Vector>();
247 0 : auto &ly = GetTestLVector();
248 0 : if (dbc_tdof_list.Size())
249 : {
250 0 : auto &ty = test_fespace.GetTVector<Vector>();
251 0 : ty = x;
252 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
253 0 : RestrictionMatrixMultTranspose(ty, ly);
254 : }
255 : else
256 : {
257 0 : RestrictionMatrixMultTranspose(x, ly);
258 : }
259 :
260 : // Apply the operator on the L-vector.
261 0 : A->MultTranspose(ly, lx);
262 :
263 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx, y);
264 0 : if (dbc_tdof_list.Size())
265 : {
266 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
267 : {
268 0 : linalg::SetSubVector(y, dbc_tdof_list, x);
269 : }
270 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
271 : {
272 0 : linalg::SetSubVector(y, dbc_tdof_list, 0.0);
273 : }
274 : }
275 : }
276 :
277 6 : void ParOperator::AddMult(const Vector &x, Vector &y, const double a) const
278 : {
279 : MFEM_ASSERT(x.Size() == width && y.Size() == height,
280 : "Incompatible dimensions for ParOperator::AddMult!");
281 6 : if (RAP)
282 : {
283 0 : RAP->AddMult(x, y, a);
284 0 : return;
285 : }
286 :
287 6 : auto &lx = trial_fespace.GetLVector<Vector>();
288 6 : auto &ly = GetTestLVector();
289 6 : if (dbc_tdof_list.Size())
290 : {
291 0 : auto &tx = trial_fespace.GetTVector<Vector>();
292 0 : tx = x;
293 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
294 0 : trial_fespace.GetProlongationMatrix()->Mult(tx, lx);
295 : }
296 : else
297 : {
298 6 : trial_fespace.GetProlongationMatrix()->Mult(x, lx);
299 : }
300 :
301 : // Apply the operator on the L-vector.
302 6 : A->Mult(lx, ly);
303 :
304 6 : auto &ty = test_fespace.GetTVector<Vector>();
305 6 : RestrictionMatrixMult(ly, ty);
306 6 : if (dbc_tdof_list.Size())
307 : {
308 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
309 : {
310 0 : linalg::SetSubVector(ty, dbc_tdof_list, x);
311 : }
312 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
313 : {
314 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
315 : }
316 : }
317 6 : y.Add(a, ty);
318 : }
319 :
320 0 : void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) const
321 : {
322 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
323 : "Incompatible dimensions for ParOperator::AddMultTranspose!");
324 0 : if (RAP)
325 : {
326 0 : RAP->AddMultTranspose(x, y, a);
327 0 : return;
328 : }
329 :
330 0 : auto &lx = trial_fespace.GetLVector<Vector>();
331 0 : auto &ly = GetTestLVector();
332 0 : if (dbc_tdof_list.Size())
333 : {
334 0 : auto &ty = test_fespace.GetTVector<Vector>();
335 0 : ty = x;
336 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
337 0 : RestrictionMatrixMultTranspose(ty, ly);
338 : }
339 : else
340 : {
341 0 : RestrictionMatrixMultTranspose(x, ly);
342 : }
343 :
344 : // Apply the operator on the L-vector.
345 0 : A->MultTranspose(ly, lx);
346 :
347 0 : auto &tx = trial_fespace.GetTVector<Vector>();
348 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx, tx);
349 0 : if (dbc_tdof_list.Size())
350 : {
351 0 : if (diag_policy == DiagonalPolicy::DIAG_ONE)
352 : {
353 0 : linalg::SetSubVector(tx, dbc_tdof_list, x);
354 : }
355 0 : else if (diag_policy == DiagonalPolicy::DIAG_ZERO)
356 : {
357 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
358 : }
359 : }
360 0 : y.Add(a, tx);
361 : }
362 :
363 9 : void ParOperator::RestrictionMatrixMult(const Vector &ly, Vector &ty) const
364 : {
365 9 : if (!use_R)
366 : {
367 9 : test_fespace.GetProlongationMatrix()->MultTranspose(ly, ty);
368 : }
369 : else
370 : {
371 0 : test_fespace.GetRestrictionMatrix()->Mult(ly, ty);
372 : }
373 9 : }
374 :
375 0 : void ParOperator::RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const
376 : {
377 0 : if (!use_R)
378 : {
379 0 : test_fespace.GetProlongationMatrix()->Mult(ty, ly);
380 : }
381 : else
382 : {
383 0 : test_fespace.GetRestrictionMatrix()->MultTranspose(ty, ly);
384 : }
385 0 : }
386 :
387 9 : Vector &ParOperator::GetTestLVector() const
388 : {
389 9 : return (&trial_fespace == &test_fespace) ? trial_fespace.GetLVector2<Vector>()
390 9 : : test_fespace.GetLVector<Vector>();
391 : }
392 :
393 9 : ComplexParOperator::ComplexParOperator(std::unique_ptr<Operator> &&dAr,
394 : std::unique_ptr<Operator> &&dAi, const Operator *pAr,
395 : const Operator *pAi,
396 : const FiniteElementSpace &trial_fespace,
397 : const FiniteElementSpace &test_fespace,
398 9 : bool test_restrict)
399 : : ComplexOperator(test_fespace.GetTrueVSize(), trial_fespace.GetTrueVSize()),
400 9 : data_A((dAr != nullptr || dAi != nullptr)
401 9 : ? std::make_unique<ComplexWrapperOperator>(std::move(dAr), std::move(dAi))
402 : : std::make_unique<ComplexWrapperOperator>(pAr, pAi)),
403 9 : A(data_A.get()), trial_fespace(trial_fespace), test_fespace(test_fespace),
404 9 : use_R(test_restrict), diag_policy(Operator::DiagonalPolicy::DIAG_ONE),
405 18 : RAPr(A->Real()
406 9 : ? std::make_unique<ParOperator>(*A->Real(), trial_fespace, test_fespace, use_R)
407 : : nullptr),
408 18 : RAPi(A->Imag()
409 9 : ? std::make_unique<ParOperator>(*A->Imag(), trial_fespace, test_fespace, use_R)
410 9 : : nullptr)
411 : {
412 : // We use the non-owning constructors for real and imaginary part ParOperators, since we
413 : // construct A as a ComplexWrapperOperator which has separate access to the real and
414 : // imaginary components.
415 9 : }
416 :
417 9 : ComplexParOperator::ComplexParOperator(std::unique_ptr<Operator> &&Ar,
418 : std::unique_ptr<Operator> &&Ai,
419 : const FiniteElementSpace &trial_fespace,
420 : const FiniteElementSpace &test_fespace,
421 9 : bool test_restrict)
422 : : ComplexParOperator(std::move(Ar), std::move(Ai), nullptr, nullptr, trial_fespace,
423 9 : test_fespace, test_restrict)
424 : {
425 9 : }
426 :
427 0 : ComplexParOperator::ComplexParOperator(const Operator *Ar, const Operator *Ai,
428 : const FiniteElementSpace &trial_fespace,
429 : const FiniteElementSpace &test_fespace,
430 0 : bool test_restrict)
431 0 : : ComplexParOperator(nullptr, nullptr, Ar, Ai, trial_fespace, test_fespace, test_restrict)
432 : {
433 0 : }
434 :
435 0 : void ComplexParOperator::SetEssentialTrueDofs(const mfem::Array<int> &tdof_list,
436 : Operator::DiagonalPolicy policy)
437 : {
438 0 : MFEM_VERIFY(policy == Operator::DiagonalPolicy::DIAG_ONE ||
439 : policy == Operator::DiagonalPolicy::DIAG_ZERO,
440 : "Essential boundary condition true dof elimination for ComplexParOperator "
441 : "supports only DiagonalPolicy::DIAG_ONE or DiagonalPolicy::DIAG_ZERO!");
442 0 : MFEM_VERIFY(
443 : policy != Operator::DiagonalPolicy::DIAG_ONE || RAPr,
444 : "DiagonalPolicy::DIAG_ONE specified for ComplexParOperator with no real part!");
445 0 : MFEM_VERIFY(height == width, "Set essential true dofs for both test and trial spaces "
446 : "for rectangular ComplexParOperator!");
447 : tdof_list.Read();
448 0 : dbc_tdof_list.MakeRef(tdof_list);
449 0 : diag_policy = policy;
450 0 : if (RAPr)
451 : {
452 0 : RAPr->SetEssentialTrueDofs(tdof_list, policy);
453 : }
454 0 : if (RAPi)
455 : {
456 0 : RAPi->SetEssentialTrueDofs(tdof_list, Operator::DiagonalPolicy::DIAG_ZERO);
457 : }
458 0 : }
459 :
460 0 : Operator::DiagonalPolicy ComplexParOperator::GetDiagonalPolicy() const
461 : {
462 0 : MFEM_VERIFY(dbc_tdof_list.Size() > 0,
463 : "There is no DiagonalPolicy if no essential dofs have been set!");
464 0 : return diag_policy;
465 : }
466 :
467 0 : void ComplexParOperator::AssembleDiagonal(ComplexVector &diag) const
468 : {
469 0 : diag.UseDevice(true);
470 : diag = 0.0;
471 0 : if (RAPr)
472 : {
473 0 : RAPr->AssembleDiagonal(diag.Real());
474 : }
475 0 : if (RAPi)
476 : {
477 0 : RAPi->AssembleDiagonal(diag.Imag());
478 : }
479 0 : }
480 :
481 3 : void ComplexParOperator::Mult(const ComplexVector &x, ComplexVector &y) const
482 : {
483 : MFEM_ASSERT(x.Size() == width && y.Size() == height,
484 : "Incompatible dimensions for ComplexParOperator::Mult!");
485 :
486 3 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
487 3 : auto &ly = GetTestLVector();
488 3 : if (dbc_tdof_list.Size())
489 : {
490 0 : auto &tx = trial_fespace.GetTVector<ComplexVector>();
491 : tx = x;
492 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
493 0 : trial_fespace.GetProlongationMatrix()->Mult(tx.Real(), lx.Real());
494 0 : trial_fespace.GetProlongationMatrix()->Mult(tx.Imag(), lx.Imag());
495 : }
496 : else
497 : {
498 3 : trial_fespace.GetProlongationMatrix()->Mult(x.Real(), lx.Real());
499 3 : trial_fespace.GetProlongationMatrix()->Mult(x.Imag(), lx.Imag());
500 : }
501 :
502 : // Apply the operator on the L-vector.
503 3 : A->Mult(lx, ly);
504 :
505 3 : RestrictionMatrixMult(ly, y);
506 3 : if (dbc_tdof_list.Size())
507 : {
508 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
509 : {
510 0 : linalg::SetSubVector(y, dbc_tdof_list, x);
511 : }
512 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
513 : {
514 0 : linalg::SetSubVector(y, dbc_tdof_list, 0.0);
515 : }
516 : }
517 3 : }
518 :
519 0 : void ComplexParOperator::MultTranspose(const ComplexVector &x, ComplexVector &y) const
520 : {
521 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
522 : "Incompatible dimensions for ComplexParOperator::MultTranspose!");
523 :
524 0 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
525 0 : auto &ly = GetTestLVector();
526 0 : if (dbc_tdof_list.Size())
527 : {
528 0 : auto &ty = test_fespace.GetTVector<ComplexVector>();
529 : ty = x;
530 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
531 0 : RestrictionMatrixMultTranspose(ty, ly);
532 : }
533 : else
534 : {
535 0 : RestrictionMatrixMultTranspose(x, ly);
536 : }
537 :
538 : // Apply the operator on the L-vector.
539 0 : A->MultTranspose(ly, lx);
540 :
541 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), y.Real());
542 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), y.Imag());
543 0 : if (dbc_tdof_list.Size())
544 : {
545 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
546 : {
547 0 : linalg::SetSubVector(y, dbc_tdof_list, x);
548 : }
549 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
550 : {
551 0 : linalg::SetSubVector(y, dbc_tdof_list, 0.0);
552 : }
553 : }
554 0 : }
555 :
556 0 : void ComplexParOperator::MultHermitianTranspose(const ComplexVector &x,
557 : ComplexVector &y) const
558 : {
559 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
560 : "Incompatible dimensions for ComplexParOperator::MultHermitianTranspose!");
561 :
562 0 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
563 0 : auto &ly = GetTestLVector();
564 0 : if (dbc_tdof_list.Size())
565 : {
566 0 : auto &ty = test_fespace.GetTVector<ComplexVector>();
567 : ty = x;
568 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
569 0 : RestrictionMatrixMultTranspose(ty, ly);
570 : }
571 : else
572 : {
573 0 : RestrictionMatrixMultTranspose(x, ly);
574 : }
575 :
576 : // Apply the operator on the L-vector.
577 0 : A->MultHermitianTranspose(ly, lx);
578 :
579 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), y.Real());
580 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), y.Imag());
581 0 : if (dbc_tdof_list.Size())
582 : {
583 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
584 : {
585 0 : linalg::SetSubVector(y, dbc_tdof_list, x);
586 : }
587 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
588 : {
589 0 : linalg::SetSubVector(y, dbc_tdof_list, 0.0);
590 : }
591 : }
592 0 : }
593 :
594 6 : void ComplexParOperator::AddMult(const ComplexVector &x, ComplexVector &y,
595 : const std::complex<double> a) const
596 : {
597 : MFEM_ASSERT(x.Size() == width && y.Size() == height,
598 : "Incompatible dimensions for ComplexParOperator::AddMult!");
599 :
600 6 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
601 6 : auto &ly = GetTestLVector();
602 6 : if (dbc_tdof_list.Size())
603 : {
604 0 : auto &tx = trial_fespace.GetTVector<ComplexVector>();
605 : tx = x;
606 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
607 0 : trial_fespace.GetProlongationMatrix()->Mult(tx.Real(), lx.Real());
608 0 : trial_fespace.GetProlongationMatrix()->Mult(tx.Imag(), lx.Imag());
609 : }
610 : else
611 : {
612 6 : trial_fespace.GetProlongationMatrix()->Mult(x.Real(), lx.Real());
613 6 : trial_fespace.GetProlongationMatrix()->Mult(x.Imag(), lx.Imag());
614 : }
615 :
616 : // Apply the operator on the L-vector.
617 6 : A->Mult(lx, ly);
618 :
619 6 : auto &ty = test_fespace.GetTVector<ComplexVector>();
620 6 : RestrictionMatrixMult(ly, ty);
621 6 : if (dbc_tdof_list.Size())
622 : {
623 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
624 : {
625 0 : linalg::SetSubVector(ty, dbc_tdof_list, x);
626 : }
627 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
628 : {
629 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
630 : }
631 : }
632 6 : y.AXPY(a, ty);
633 6 : }
634 :
635 0 : void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector &y,
636 : const std::complex<double> a) const
637 : {
638 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
639 : "Incompatible dimensions for ComplexParOperator::AddMultTranspose!");
640 :
641 0 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
642 0 : auto &ly = GetTestLVector();
643 0 : if (dbc_tdof_list.Size())
644 : {
645 0 : auto &ty = test_fespace.GetTVector<ComplexVector>();
646 : ty = x;
647 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
648 0 : RestrictionMatrixMultTranspose(ty, ly);
649 : }
650 : else
651 : {
652 0 : RestrictionMatrixMultTranspose(x, ly);
653 : }
654 :
655 : // Apply the operator on the L-vector.
656 0 : A->MultTranspose(ly, lx);
657 :
658 0 : auto &tx = trial_fespace.GetTVector<ComplexVector>();
659 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
660 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
661 0 : if (dbc_tdof_list.Size())
662 : {
663 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
664 : {
665 0 : linalg::SetSubVector(tx, dbc_tdof_list, x);
666 : }
667 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
668 : {
669 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
670 : }
671 : }
672 0 : y.AXPY(a, tx);
673 0 : }
674 :
675 0 : void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
676 : const std::complex<double> a) const
677 : {
678 : MFEM_ASSERT(x.Size() == height && y.Size() == width,
679 : "Incompatible dimensions for ComplexParOperator::AddMultHermitianTranspose!");
680 :
681 0 : auto &lx = trial_fespace.GetLVector<ComplexVector>();
682 0 : auto &ly = GetTestLVector();
683 0 : if (dbc_tdof_list.Size())
684 : {
685 0 : auto &ty = test_fespace.GetTVector<ComplexVector>();
686 : ty = x;
687 0 : linalg::SetSubVector(ty, dbc_tdof_list, 0.0);
688 0 : RestrictionMatrixMultTranspose(ty, ly);
689 : }
690 : else
691 : {
692 0 : RestrictionMatrixMultTranspose(x, ly);
693 : }
694 :
695 : // Apply the operator on the L-vector.
696 0 : A->MultHermitianTranspose(ly, lx);
697 :
698 0 : auto &tx = trial_fespace.GetTVector<ComplexVector>();
699 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
700 0 : trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
701 0 : if (dbc_tdof_list.Size())
702 : {
703 0 : if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
704 : {
705 0 : linalg::SetSubVector(tx, dbc_tdof_list, x);
706 : }
707 0 : else if (diag_policy == Operator::DiagonalPolicy::DIAG_ZERO)
708 : {
709 0 : linalg::SetSubVector(tx, dbc_tdof_list, 0.0);
710 : }
711 : }
712 0 : y.AXPY(a, tx);
713 0 : }
714 :
715 9 : void ComplexParOperator::RestrictionMatrixMult(const ComplexVector &ly,
716 : ComplexVector &ty) const
717 : {
718 9 : if (!use_R)
719 : {
720 9 : test_fespace.GetProlongationMatrix()->MultTranspose(ly.Real(), ty.Real());
721 9 : test_fespace.GetProlongationMatrix()->MultTranspose(ly.Imag(), ty.Imag());
722 : }
723 : else
724 : {
725 0 : test_fespace.GetRestrictionMatrix()->Mult(ly.Real(), ty.Real());
726 0 : test_fespace.GetRestrictionMatrix()->Mult(ly.Imag(), ty.Imag());
727 : }
728 9 : }
729 :
730 0 : void ComplexParOperator::RestrictionMatrixMultTranspose(const ComplexVector &ty,
731 : ComplexVector &ly) const
732 : {
733 0 : if (!use_R)
734 : {
735 0 : test_fespace.GetProlongationMatrix()->Mult(ty.Real(), ly.Real());
736 0 : test_fespace.GetProlongationMatrix()->Mult(ty.Imag(), ly.Imag());
737 : }
738 : else
739 : {
740 0 : test_fespace.GetRestrictionMatrix()->MultTranspose(ty.Real(), ly.Real());
741 0 : test_fespace.GetRestrictionMatrix()->MultTranspose(ty.Imag(), ly.Imag());
742 : }
743 0 : }
744 :
745 9 : ComplexVector &ComplexParOperator::GetTestLVector() const
746 : {
747 9 : return (&trial_fespace == &test_fespace) ? trial_fespace.GetLVector2<ComplexVector>()
748 9 : : test_fespace.GetLVector<ComplexVector>();
749 : }
750 :
751 : // Helper that checks if two containers (Vector or Array<T>) are actually references to the
752 : // same underlying data.
753 : template <typename C>
754 0 : bool ReferencesSameMemory(const C &c1, const C &c2)
755 : {
756 : const auto &m1 = c1.GetMemory();
757 : const auto &m2 = c2.GetMemory();
758 0 : return (m1.HostIsValid() && m2.HostIsValid() && c1.HostRead() == c2.HostRead()) ||
759 0 : (m1.DeviceIsValid() && m2.DeviceIsValid() && c1.Read() == c2.Read());
760 : }
761 :
762 : // Combine a collection of ParOperator into a weighted summation. If set_essential is true,
763 : // extract the essential dofs from the operator array, and apply to the summed operator.
764 : template <std::size_t N>
765 : std::unique_ptr<ParOperator>
766 3 : BuildParSumOperator(const std::array<double, N> &coeff,
767 : const std::array<const ParOperator *, N> &ops, bool set_essential)
768 : {
769 : auto it = std::find_if(ops.begin(), ops.end(), [](auto p) { return p != nullptr; });
770 3 : MFEM_VERIFY(it != ops.end(),
771 : "BuildParSumOperator requires at least one valid ParOperator!");
772 3 : const auto first_op = *it;
773 : const auto &fespace = first_op->TrialFiniteElementSpace();
774 12 : MFEM_VERIFY(
775 : std::all_of(ops.begin(), ops.end(), [&fespace](auto p)
776 : { return p == nullptr || &p->TrialFiniteElementSpace() == &fespace; }),
777 : "All ComplexParOperators must have the same FiniteElementSpace!");
778 :
779 3 : auto sum = std::make_unique<SumOperator>(first_op->LocalOperator().Height(),
780 3 : first_op->LocalOperator().Width());
781 12 : for (std::size_t i = 0; i < coeff.size(); i++)
782 : {
783 9 : if (ops[i] && coeff[i] != 0)
784 : {
785 6 : sum->AddOperator(ops[i]->LocalOperator(), coeff[i]);
786 : }
787 : }
788 :
789 3 : auto O = std::make_unique<ParOperator>(std::move(sum), fespace);
790 3 : if (set_essential)
791 : {
792 : // Extract essential dof pointer from first operator with one.
793 : auto it_ess = std::find_if(ops.begin(), ops.end(), [](auto p)
794 9 : { return p != nullptr && p->GetEssentialTrueDofs(); });
795 3 : if (it_ess == ops.end())
796 : {
797 3 : return O;
798 : }
799 0 : const auto *ess_dofs = (*it_ess)->GetEssentialTrueDofs();
800 :
801 : // Check other existant essential dof arrays are references.
802 0 : MFEM_VERIFY(std::all_of(ops.begin(), ops.end(),
803 : [&](auto p)
804 : {
805 : if (p == nullptr)
806 : {
807 : return true;
808 : }
809 : auto p_ess_dofs = p->GetEssentialTrueDofs();
810 : return p_ess_dofs == nullptr ||
811 : ReferencesSameMemory(*ess_dofs, *p_ess_dofs);
812 : }),
813 : "If essential dofs are set, all suboperators must agree on them!");
814 :
815 : // Use implied ordering of enumeration.
816 0 : Operator::DiagonalPolicy policy = Operator::DiagonalPolicy::DIAG_ZERO;
817 0 : for (auto p : ops)
818 : {
819 0 : policy = (p && p->GetEssentialTrueDofs()) ? std::max(policy, p->GetDiagonalPolicy())
820 : : policy;
821 : }
822 0 : O->SetEssentialTrueDofs(*ess_dofs, policy);
823 : }
824 :
825 : return O;
826 : }
827 :
828 : // Combine a collection of ComplexParOperator into a weighted summation. If set_essential is
829 : // true, extract the essential dofs from the operator array, and apply to the summed
830 : // operator.
831 : template <std::size_t N>
832 : std::unique_ptr<ComplexParOperator>
833 3 : BuildParSumOperator(const std::array<std::complex<double>, N> &coeff,
834 : const std::array<const ComplexParOperator *, N> &ops,
835 : bool set_essential)
836 : {
837 : auto it = std::find_if(ops.begin(), ops.end(), [](auto p) { return p != nullptr; });
838 3 : MFEM_VERIFY(it != ops.end(),
839 : "BuildParSumOperator requires at least one valid ComplexParOperator!");
840 3 : const auto first_op = *it;
841 : const auto &fespace = first_op->TrialFiniteElementSpace();
842 12 : MFEM_VERIFY(
843 : std::all_of(ops.begin(), ops.end(), [&fespace](auto p)
844 : { return p == nullptr || &p->TrialFiniteElementSpace() == &fespace; }),
845 : "All ComplexParOperators must have the same FiniteElementSpace!");
846 :
847 3 : auto sumr = std::make_unique<SumOperator>(first_op->LocalOperator().Height(),
848 3 : first_op->LocalOperator().Width());
849 3 : auto sumi = std::make_unique<SumOperator>(first_op->LocalOperator().Height(),
850 3 : first_op->LocalOperator().Width());
851 12 : for (std::size_t i = 0; i < coeff.size(); i++)
852 : {
853 9 : if (ops[i] && coeff[i].real() != 0)
854 : {
855 6 : if (ops[i]->LocalOperator().Real())
856 : {
857 6 : sumr->AddOperator(*ops[i]->LocalOperator().Real(), coeff[i].real());
858 : }
859 6 : if (ops[i]->LocalOperator().Imag())
860 : {
861 6 : sumi->AddOperator(*ops[i]->LocalOperator().Imag(), coeff[i].real());
862 : }
863 : }
864 9 : if (ops[i] && coeff[i].imag() != 0)
865 : {
866 6 : if (ops[i]->LocalOperator().Imag())
867 : {
868 6 : sumr->AddOperator(*ops[i]->LocalOperator().Imag(), -coeff[i].imag());
869 : }
870 6 : if (ops[i]->LocalOperator().Real())
871 : {
872 6 : sumi->AddOperator(*ops[i]->LocalOperator().Real(), coeff[i].imag());
873 : }
874 : }
875 : }
876 3 : auto O = std::make_unique<ComplexParOperator>(std::move(sumr), std::move(sumi), fespace);
877 3 : if (set_essential)
878 : {
879 : // Extract essential dof pointer from first operator with one.
880 : auto it_ess = std::find_if(ops.begin(), ops.end(), [](auto p)
881 9 : { return p != nullptr && p->GetEssentialTrueDofs(); });
882 3 : if (it_ess == ops.end())
883 : {
884 3 : return O;
885 : }
886 0 : const auto *ess_dofs = (*it_ess)->GetEssentialTrueDofs();
887 :
888 : // Check other existant essential dof arrays are references.
889 0 : MFEM_VERIFY(std::all_of(ops.begin(), ops.end(),
890 : [&](auto p)
891 : {
892 : if (p == nullptr)
893 : {
894 : return true;
895 : }
896 : auto p_ess_dofs = p->GetEssentialTrueDofs();
897 : return p_ess_dofs == nullptr ||
898 : ReferencesSameMemory(*ess_dofs, *p_ess_dofs);
899 : }),
900 : "If essential dofs are set, all suboperators must agree on them!");
901 :
902 : // Use implied ordering of enumeration.
903 0 : Operator::DiagonalPolicy policy = Operator::DiagonalPolicy::DIAG_ZERO;
904 0 : for (auto p : ops)
905 : {
906 0 : policy = (p && p->GetEssentialTrueDofs()) ? std::max(policy, p->GetDiagonalPolicy())
907 : : policy;
908 : }
909 0 : O->SetEssentialTrueDofs(*ess_dofs, policy);
910 : }
911 : return O;
912 : }
913 :
914 : // TODO: replace with std::to_array in c++20.
915 : namespace detail
916 : {
917 : // Helper for conversion to std::array.
918 : template <class T, std::size_t N, std::size_t... I>
919 : constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&&a)[N],
920 : std::index_sequence<I...>)
921 : {
922 6 : return {{std::move(a[I])...}};
923 : }
924 : } // namespace detail
925 :
926 : template <class T, std::size_t N>
927 : constexpr std::array<std::remove_cv_t<T>, N> to_array(T (&&a)[N])
928 : {
929 : return detail::to_array_impl(std::move(a), std::make_index_sequence<N>{});
930 : }
931 :
932 : template <std::size_t N>
933 : std::unique_ptr<ComplexParOperator>
934 : BuildParSumOperator(std::complex<double> (&&coeff_in)[N],
935 : const ComplexParOperator *(&&ops_in)[N], bool set_essential)
936 : {
937 : return BuildParSumOperator(to_array<std::complex<double>>(std::move(coeff_in)),
938 : to_array<const ComplexParOperator *>(std::move(ops_in)),
939 : set_essential);
940 : }
941 :
942 : template <std::size_t N, typename ScalarType, typename OperType>
943 : std::unique_ptr<std::conditional_t<std::is_base_of_v<ComplexOperator, OperType>,
944 : ComplexParOperator, ParOperator>>
945 6 : BuildParSumOperator(ScalarType (&&coeff_in)[N], const OperType *(&&ops_in)[N],
946 : bool set_essential)
947 : {
948 : using ParOperType =
949 : typename std::conditional_t<std::is_base_of_v<ComplexOperator, OperType>,
950 : ComplexParOperator, ParOperator>;
951 :
952 : std::array<const ParOperType *, N> par_ops;
953 6 : std::transform(ops_in, ops_in + N, par_ops.begin(),
954 18 : [](const OperType *op) { return dynamic_cast<const ParOperType *>(op); });
955 :
956 6 : return BuildParSumOperator(to_array<ScalarType>(std::move(coeff_in)), std::move(par_ops),
957 6 : set_essential);
958 : }
959 :
960 : // Explicit instantiation.
961 : template std::unique_ptr<ParOperator> BuildParSumOperator(double (&&)[2],
962 : const Operator *(&&)[2], bool);
963 : template std::unique_ptr<ParOperator> BuildParSumOperator(double (&&)[3],
964 : const Operator *(&&)[3], bool);
965 : template std::unique_ptr<ParOperator> BuildParSumOperator(double (&&)[4],
966 : const Operator *(&&)[4], bool);
967 : template std::unique_ptr<ComplexParOperator>
968 : BuildParSumOperator(std::complex<double> (&&)[2], const ComplexOperator *(&&)[2], bool);
969 : template std::unique_ptr<ComplexParOperator>
970 : BuildParSumOperator(std::complex<double> (&&)[3], const ComplexOperator *(&&)[3], bool);
971 : template std::unique_ptr<ComplexParOperator>
972 : BuildParSumOperator(std::complex<double> (&&)[4], const ComplexOperator *(&&)[4], bool);
973 :
974 : } // namespace palace
|