Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "densematrix.hpp"
5 :
6 : #include <functional>
7 : #include <limits>
8 : #include <mfem.hpp>
9 : #include <mfem/linalg/kernels.hpp>
10 :
11 : namespace palace
12 : {
13 :
14 : namespace
15 : {
16 :
17 : // Compute matrix functions for symmetric real-valued 2x2 or 3x3 matrices. Returns the
18 : // matrix U * f(Λ) * U' for input U * Λ * U'.
19 : // Reference: Deledalle et al., Closed-form expressions of the eigen decomposition of 2x2
20 : // and 3x3 Hermitian matrices, HAL hal-01501221 (2017).
21 114 : mfem::DenseMatrix MatrixFunction(const mfem::DenseMatrix &M,
22 : const std::function<double(const double &)> &functor)
23 : {
24 : MFEM_ASSERT(M.Height() == M.Width(),
25 : "MatrixFunction only available for square matrices!");
26 : const auto N = M.Height();
27 : constexpr auto tol = 10.0 * std::numeric_limits<double>::epsilon();
28 456 : for (int i = 0; i < N; i++)
29 : {
30 684 : for (int j = i + 1; j < N; j++)
31 : {
32 342 : MFEM_VERIFY(std::abs(M(i, j) - M(j, i)) < tol,
33 : "MatrixFunction only available for symmetric matrices ("
34 : << M(i, j) << " != " << M(j, i) << ")!");
35 : }
36 : }
37 114 : mfem::DenseMatrix Mout(N, N);
38 114 : Mout = 0.0;
39 114 : if (N == 2)
40 : {
41 0 : MFEM_ABORT("2x2 MatrixFunction is not implemented yet!");
42 : }
43 114 : else if (N == 3)
44 : {
45 : // Need to specialize based on the number of zeros and their locations.
46 : const auto &a = M(0, 0), &b = M(1, 1), &c = M(2, 2);
47 : const auto &d = M(0, 1), &e = M(1, 2), &f = M(0, 2);
48 114 : const bool d_non_zero = std::abs(d) > tol;
49 114 : const bool e_non_zero = std::abs(e) > tol;
50 114 : const bool f_non_zero = std::abs(f) > tol;
51 114 : if (!d_non_zero && !e_non_zero && !f_non_zero)
52 : {
53 : // a 0 0
54 : // 0 b 0
55 : // 0 0 c
56 456 : for (int i = 0; i < 3; i++)
57 : {
58 342 : Mout(i, i) = functor(M(i, i));
59 : }
60 : return Mout;
61 : }
62 0 : if (d_non_zero && !e_non_zero && !f_non_zero)
63 : {
64 : // a d 0
65 : // d b 0
66 : // 0 0 c
67 0 : const double disc = std::sqrt(a * a - 2.0 * a * b + b * b + 4.0 * d * d);
68 0 : const double lambda1 = c;
69 0 : const double lambda2 = (a + b - disc) / 2.0;
70 0 : const double lambda3 = (a + b + disc) / 2.0;
71 0 : const mfem::Vector v1{{0.0, 0.0, 1.0}};
72 0 : const mfem::Vector v2{{-(-a + b + disc) / (2.0 * d), 1.0, 0.0}};
73 0 : const mfem::Vector v3{{-(-a + b - disc) / (2.0 * d), 1.0, 0.0}};
74 0 : AddMult_a_VVt(functor(lambda1), v1, Mout);
75 0 : AddMult_a_VVt(functor(lambda2), v2, Mout);
76 0 : AddMult_a_VVt(functor(lambda3), v3, Mout);
77 : return Mout;
78 : }
79 0 : if (!d_non_zero && e_non_zero && !f_non_zero)
80 : {
81 : // a 0 0
82 : // 0 b e
83 : // 0 e c
84 0 : const double disc = std::sqrt(b * b - 2.0 * b * c + c * c + 4.0 * e * e);
85 0 : const double lambda1 = a;
86 0 : const double lambda2 = 0.5 * (b + c - disc);
87 0 : const double lambda3 = 0.5 * (b + c + disc);
88 0 : const mfem::Vector v1{{1.0, 0.0, 0.0}};
89 0 : const mfem::Vector v2{{0.0, -(-b + c + disc) / (2.0 * e), 1.0}};
90 0 : const mfem::Vector v3{{0.0, -(-b + c - disc) / (2.0 * e), 1.0}};
91 0 : AddMult_a_VVt(functor(lambda1), v1, Mout);
92 0 : AddMult_a_VVt(functor(lambda2), v2, Mout);
93 0 : AddMult_a_VVt(functor(lambda3), v3, Mout);
94 : return Mout;
95 : }
96 0 : if (!d_non_zero && !e_non_zero && f_non_zero)
97 : {
98 : // a 0 f
99 : // 0 b 0
100 : // f 0 c
101 0 : const double disc = std::sqrt(a * a - 2.0 * a * c + c * c + 4.0 * f * f);
102 0 : const double lambda1 = b;
103 0 : const double lambda2 = 0.5 * (a + c - disc);
104 0 : const double lambda3 = 0.5 * (a + c + disc);
105 0 : const mfem::Vector v1{{0.0, 1.0, 0.0}};
106 0 : const mfem::Vector v2{{-(-a + c + disc) / (2.0 * f), 0.0, 1.0}};
107 0 : const mfem::Vector v3{{-(-a + c - disc) / (2.0 * f), 0.0, 1.0}};
108 0 : AddMult_a_VVt(functor(lambda1), v1, Mout);
109 0 : AddMult_a_VVt(functor(lambda2), v2, Mout);
110 0 : AddMult_a_VVt(functor(lambda3), v3, Mout);
111 : return Mout;
112 : }
113 0 : if ((!d_non_zero && e_non_zero && f_non_zero) ||
114 0 : (d_non_zero && !e_non_zero && f_non_zero) ||
115 0 : (d_non_zero && e_non_zero && !f_non_zero))
116 : {
117 0 : MFEM_ABORT("This nonzero pattern is not currently supported for MatrixFunction!");
118 : }
119 : // General case for all nonzero:
120 : // a d f
121 : // d b e
122 : // f e c
123 0 : const double a2 = a * a, b2 = b * b, c2 = c * c, d2 = d * d, e2 = e * e, f2 = f * f;
124 0 : const double a2mbmc = 2.0 * a - b - c;
125 0 : const double b2mamc = 2.0 * b - a - c;
126 0 : const double c2mamb = 2.0 * c - a - b;
127 0 : const double x1 = a2 + b2 + c2 - a * b - b * c + 3.0 * (d2 + e2 + f2);
128 0 : const double x2 = -(a2mbmc * b2mamc * c2mamb) +
129 0 : 9.0 * (c2mamb * d2 + b2mamc * f2 + a2mbmc * e2) - 54.0 * d * e * f;
130 0 : const double phi = std::atan2(std::sqrt(4.0 * x1 * x1 * x1 - x2 * x2), x2);
131 0 : const double lambda1 = (a + b + c - 2.0 * std::sqrt(x1) * std::cos(phi / 3.0)) / 3.0;
132 : const double lambda2 =
133 0 : (a + b + c + 2.0 * std::sqrt(x1) * std::cos((phi - M_PI) / 3.0)) / 3.0;
134 : const double lambda3 =
135 0 : (a + b + c + 2.0 * std::sqrt(x1) * std::cos((phi + M_PI) / 3.0)) / 3.0;
136 :
137 0 : auto SafeDivide = [&](double x, double y)
138 : {
139 0 : if (std::abs(x) <= tol)
140 : {
141 : return 0.0;
142 : }
143 0 : if (std::abs(x) >= tol && std::abs(y) <= tol)
144 : {
145 0 : MFEM_ABORT("Logic error: Zero denominator with nonzero numerator!");
146 : return 0.0;
147 : }
148 0 : return x / y;
149 : };
150 0 : const double m1 = SafeDivide(d * (c - lambda1) - e * f, f * (b - lambda1) - d * e);
151 0 : const double m2 = SafeDivide(d * (c - lambda2) - e * f, f * (b - lambda2) - d * e);
152 0 : const double m3 = SafeDivide(d * (c - lambda3) - e * f, f * (b - lambda3) - d * e);
153 0 : const double l1mcmem1 = lambda1 - c - e * m1;
154 0 : const double l2mcmem2 = lambda2 - c - e * m2;
155 0 : const double l3mcmem3 = lambda3 - c - e * m3;
156 0 : const double n1 = 1.0 + m1 * m1 + SafeDivide(std::pow(l1mcmem1, 2), f2);
157 0 : const double n2 = 1.0 + m2 * m2 + SafeDivide(std::pow(l2mcmem2, 2), f2);
158 0 : const double n3 = 1.0 + m3 * m3 + SafeDivide(std::pow(l3mcmem3, 2), f2);
159 0 : const double tlambda1 = functor(lambda1) / n1;
160 0 : const double tlambda2 = functor(lambda2) / n2;
161 0 : const double tlambda3 = functor(lambda3) / n3;
162 :
163 0 : const double at = (tlambda1 * l1mcmem1 * l1mcmem1 + tlambda2 * l2mcmem2 * l2mcmem2 +
164 0 : tlambda3 * l3mcmem3 * l3mcmem3) /
165 : f2;
166 0 : const double bt = tlambda1 * m1 * m1 + tlambda2 * m2 * m2 + tlambda3 * m3 * m3;
167 0 : const double ct = tlambda1 + tlambda2 + tlambda3;
168 0 : const double dt =
169 0 : (tlambda1 * m1 * l1mcmem1 + tlambda2 * m2 * l2mcmem2 + tlambda3 * m3 * l3mcmem3) /
170 0 : f;
171 0 : const double et = tlambda1 * m1 + tlambda2 * m2 + tlambda3 * m3;
172 0 : const double ft = (tlambda1 * l1mcmem1 + tlambda2 * l2mcmem2 + tlambda3 * l3mcmem3) / f;
173 0 : Mout(0, 0) = at;
174 0 : Mout(0, 1) = dt;
175 0 : Mout(0, 2) = ft;
176 0 : Mout(1, 0) = dt;
177 0 : Mout(1, 1) = bt;
178 0 : Mout(1, 2) = et;
179 0 : Mout(2, 0) = ft;
180 0 : Mout(2, 1) = et;
181 0 : Mout(2, 2) = ct;
182 0 : return Mout;
183 : }
184 : else
185 : {
186 0 : MFEM_ABORT("MatrixFunction only supports 2x2 or 3x3 matrices, N: " << N << "!");
187 : }
188 : return Mout;
189 0 : }
190 :
191 : } // namespace
192 :
193 : namespace linalg
194 : {
195 :
196 76 : mfem::DenseMatrix MatrixSqrt(const mfem::DenseMatrix &M)
197 : {
198 152 : return MatrixFunction(M, [](auto s) { return std::sqrt(s); });
199 : }
200 :
201 0 : mfem::DenseTensor MatrixSqrt(const mfem::DenseTensor &T)
202 : {
203 0 : mfem::DenseTensor S(T);
204 0 : mfem::DenseMatrix buffS, buffT;
205 0 : for (int k = 0; k < T.SizeK(); k++)
206 : {
207 0 : S(k, buffS) = MatrixSqrt(T(k, buffT));
208 : }
209 0 : return S;
210 0 : }
211 :
212 38 : mfem::DenseMatrix MatrixPow(const mfem::DenseMatrix &M, double p)
213 : {
214 76 : return MatrixFunction(M, [p](auto s) { return std::pow(s, p); });
215 : }
216 :
217 0 : mfem::DenseTensor MatrixPow(const mfem::DenseTensor &T, double p)
218 : {
219 0 : mfem::DenseTensor S(T);
220 0 : mfem::DenseMatrix buffS, buffT;
221 0 : for (int k = 0; k < T.SizeK(); k++)
222 : {
223 0 : S(k, buffS) = MatrixPow(T(k, buffT), p);
224 : }
225 0 : return S;
226 0 : }
227 :
228 38 : double SingularValueMax(const mfem::DenseMatrix &M)
229 : {
230 : MFEM_ASSERT(
231 : M.Height() == M.Width() && M.Height() > 0 && M.Height() <= 3,
232 : "Matrix singular values only available for square matrices of dimension <= 3!");
233 : const int N = M.Height();
234 38 : if (N == 1)
235 : {
236 0 : return M(0, 0);
237 : }
238 38 : else if (N == 2)
239 : {
240 0 : return mfem::kernels::CalcSingularvalue<2>(M.Data(), 0);
241 : }
242 : else
243 : {
244 38 : return mfem::kernels::CalcSingularvalue<3>(M.Data(), 0);
245 : }
246 : }
247 :
248 38 : double SingularValueMin(const mfem::DenseMatrix &M)
249 : {
250 : MFEM_ASSERT(
251 : M.Height() == M.Width() && M.Height() > 0 && M.Height() <= 3,
252 : "Matrix singular values only available for square matrices of dimension <= 3!");
253 : const int N = M.Height();
254 38 : if (N == 1)
255 : {
256 0 : return M(0, 0);
257 : }
258 38 : else if (N == 2)
259 : {
260 0 : return mfem::kernels::CalcSingularvalue<2>(M.Data(), 1);
261 : }
262 : else
263 : {
264 38 : return mfem::kernels::CalcSingularvalue<3>(M.Data(), 2);
265 : }
266 : }
267 :
268 0 : mfem::DenseTensor Mult(const mfem::DenseTensor &A, const mfem::DenseTensor &B)
269 : {
270 0 : MFEM_VERIFY(A.SizeK() == B.SizeK(),
271 : "Size mismatch for product of two DenseTensor objects!");
272 0 : mfem::DenseTensor C(A.SizeI(), B.SizeJ(), A.SizeK());
273 0 : mfem::DenseMatrix buffA, buffB, buffC;
274 0 : for (int k = 0; k < C.SizeK(); k++)
275 : {
276 0 : Mult(A(k, buffA), B(k, buffB), C(k, buffC));
277 : }
278 0 : return C;
279 0 : }
280 :
281 : } // namespace linalg
282 :
283 : } // namespace palace
|