Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "solver.hpp"
5 :
6 : #include "linalg/rap.hpp"
7 :
8 : namespace palace
9 : {
10 :
11 : template <>
12 0 : void MfemWrapperSolver<Operator>::SetOperator(const Operator &op)
13 : {
14 : // Operator is always assembled as a HypreParMatrix.
15 0 : if (const auto *hA = dynamic_cast<const mfem::HypreParMatrix *>(&op))
16 : {
17 0 : pc->SetOperator(*hA);
18 : }
19 : else
20 : {
21 0 : const auto *PtAP = dynamic_cast<const ParOperator *>(&op);
22 0 : MFEM_VERIFY(PtAP,
23 : "MfemWrapperSolver must be able to construct a HypreParMatrix operator!");
24 0 : pc->SetOperator(!save_assembled ? *PtAP->StealParallelAssemble()
25 0 : : PtAP->ParallelAssemble());
26 : }
27 0 : this->height = op.Height();
28 0 : this->width = op.Width();
29 0 : }
30 :
31 : template <>
32 0 : void MfemWrapperSolver<ComplexOperator>::SetOperator(const ComplexOperator &op)
33 : {
34 : // Assemble the real and imaginary parts, then add.
35 : // XX TODO: Test complex matrix assembly if coarse solve supports it.
36 0 : const mfem::HypreParMatrix *hAr = dynamic_cast<const mfem::HypreParMatrix *>(op.Real());
37 0 : const mfem::HypreParMatrix *hAi = dynamic_cast<const mfem::HypreParMatrix *>(op.Imag());
38 : const ParOperator *PtAPr = nullptr, *PtAPi = nullptr;
39 0 : if (op.Real() && !hAr)
40 : {
41 0 : PtAPr = dynamic_cast<const ParOperator *>(op.Real());
42 0 : MFEM_VERIFY(PtAPr,
43 : "MfemWrapperSolver must be able to construct a HypreParMatrix operator!");
44 0 : hAr = &PtAPr->ParallelAssemble();
45 : }
46 0 : if (op.Imag() && !hAi)
47 : {
48 0 : PtAPi = dynamic_cast<const ParOperator *>(op.Imag());
49 0 : MFEM_VERIFY(PtAPi,
50 : "MfemWrapperSolver must be able to construct a HypreParMatrix operator!");
51 0 : hAi = &PtAPi->ParallelAssemble();
52 : }
53 0 : if (hAr && hAi)
54 : {
55 0 : if (complex_matrix)
56 : {
57 : // A = [Ar, -Ai]
58 : // [Ai, Ar]
59 : mfem::Array2D<const mfem::HypreParMatrix *> blocks(2, 2);
60 : mfem::Array2D<double> block_coeffs(2, 2);
61 0 : blocks(0, 0) = hAr;
62 0 : blocks(0, 1) = hAi;
63 0 : blocks(1, 0) = hAi;
64 0 : blocks(1, 1) = hAr;
65 0 : block_coeffs(0, 0) = 1.0;
66 0 : block_coeffs(0, 1) = -1.0;
67 0 : block_coeffs(1, 0) = 1.0;
68 0 : block_coeffs(1, 1) = 1.0;
69 0 : A.reset(mfem::HypreParMatrixFromBlocks(blocks, &block_coeffs));
70 : }
71 : else
72 : {
73 : // A = Ar + Ai.
74 0 : A.reset(mfem::Add(1.0, *hAr, 1.0, *hAi));
75 : }
76 0 : if (PtAPr)
77 : {
78 : PtAPr->StealParallelAssemble();
79 : }
80 0 : if (PtAPi)
81 : {
82 : PtAPi->StealParallelAssemble();
83 : }
84 0 : pc->SetOperator(*A);
85 0 : if (!save_assembled)
86 : {
87 : A.reset();
88 : }
89 : }
90 0 : else if (hAr)
91 : {
92 0 : pc->SetOperator(*hAr);
93 0 : if (PtAPr && !save_assembled)
94 : {
95 : PtAPr->StealParallelAssemble();
96 : }
97 : }
98 0 : else if (hAi)
99 : {
100 0 : pc->SetOperator(*hAi);
101 0 : if (PtAPi && !save_assembled)
102 : {
103 : PtAPi->StealParallelAssemble();
104 : }
105 : }
106 : else
107 : {
108 0 : MFEM_ABORT("Empty ComplexOperator for MfemWrapperSolver!");
109 : }
110 0 : this->height = op.Height();
111 0 : this->width = op.Width();
112 0 : }
113 :
114 : template <>
115 0 : void MfemWrapperSolver<Operator>::Mult(const Vector &x, Vector &y) const
116 : {
117 0 : pc->Mult(x, y);
118 0 : }
119 :
120 : template <>
121 0 : void MfemWrapperSolver<ComplexOperator>::Mult(const ComplexVector &x,
122 : ComplexVector &y) const
123 : {
124 0 : if (pc->Height() == x.Size())
125 : {
126 : mfem::Array<const Vector *> X(2);
127 : mfem::Array<Vector *> Y(2);
128 0 : X[0] = &x.Real();
129 0 : X[1] = &x.Imag();
130 0 : Y[0] = &y.Real();
131 0 : Y[1] = &y.Imag();
132 0 : pc->ArrayMult(X, Y);
133 : }
134 : else
135 : {
136 : const int Nx = x.Size(), Ny = y.Size();
137 0 : Vector X(2 * Nx), Y(2 * Ny), yr, yi;
138 : X.UseDevice(true);
139 : Y.UseDevice(true);
140 : yr.UseDevice(true);
141 : yi.UseDevice(true);
142 0 : linalg::SetSubVector(X, 0, x.Real());
143 0 : linalg::SetSubVector(X, Nx, x.Imag());
144 0 : pc->Mult(X, Y);
145 : Y.ReadWrite();
146 : yr.MakeRef(Y, 0, Ny);
147 : yi.MakeRef(Y, Ny, Ny);
148 0 : y.Real() = yr;
149 0 : y.Imag() = yi;
150 : }
151 0 : }
152 :
153 : } // namespace palace
|