Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_VECTOR_HPP
5 : #define PALACE_LINALG_VECTOR_HPP
6 :
#include <cmath>
#include <complex>
#include <type_traits>
#include <utility>
#include <vector>
#include <mfem.hpp>
#include "utils/communication.hpp"
11 :
12 : namespace palace
13 : {
14 :
//
// Functionality extending mfem::Vector from MFEM, including basic functions for parallel
// vectors distributed across MPI processes.
//

// Real-valued vector type used throughout this header. It is a type alias (not a wrapper)
// for mfem::Vector, so every mfem::Vector member function is available on it directly.
using Vector = mfem::Vector;
21 :
22 : // A complex-valued vector represented as two real vectors, one for each component.
23 61941 : class ComplexVector
24 : {
25 : private:
26 : Vector xr, xi;
27 :
28 : public:
29 : // Create a vector with the given size.
30 : ComplexVector(int size = 0);
31 :
32 : // Copy constructor.
33 : ComplexVector(const ComplexVector &y);
34 :
35 : // Copy constructor from separately provided real and imaginary parts.
36 : ComplexVector(const Vector &yr, const Vector &yi);
37 :
38 : // Copy constructor from an array of complex values.
39 : ComplexVector(const std::complex<double> *py, int size, bool on_dev);
40 :
41 : // Create a vector referencing the memory of another vector, at the given base offset and
42 : // size.
43 : ComplexVector(Vector &y, int offset, int size);
44 :
45 : // Flag for runtime execution on the mfem::Device. See the documentation for mfem::Vector.
46 : void UseDevice(bool use_dev);
47 : bool UseDevice() const { return xr.UseDevice(); }
48 :
49 : // Return the size of the vector.
50 : int Size() const { return xr.Size(); }
51 :
52 : // Set the size of the vector. See the notes for Vector::SetSize for behavior in the cases
53 : // where the new size is less than or greater than Size() or Capacity().
54 : void SetSize(int size);
55 :
56 : // Set this vector to reference the memory of another vector, at the given base offset and
57 : // size.
58 : void MakeRef(Vector &y, int offset, int size);
59 :
60 : // Get access to the real and imaginary vector parts.
61 33 : const Vector &Real() const { return xr; }
62 66 : Vector &Real() { return xr; }
63 33 : const Vector &Imag() const { return xi; }
64 60 : Vector &Imag() { return xi; }
65 :
66 : // Set from a ComplexVector, without resizing.
67 : void Set(const ComplexVector &y);
68 : ComplexVector &operator=(const ComplexVector &y)
69 : {
70 0 : Set(y);
71 0 : return *this;
72 : }
73 :
74 : // Set from separately provided real and imaginary parts, without resizing.
75 : void Set(const Vector &yr, const Vector &yi);
76 :
77 : // Set from an array of complex values, without resizing.
78 : void Set(const std::complex<double> *py, int size, bool on_dev);
79 :
80 : // Copy the vector into an array of complex values.
81 : void Get(std::complex<double> *py, int size, bool on_dev) const;
82 :
83 : // Set all entries equal to s.
84 : ComplexVector &operator=(std::complex<double> s);
85 : ComplexVector &operator=(double s)
86 : {
87 9 : *this = std::complex<double>(s, 0.0);
88 0 : return *this;
89 : }
90 :
91 : // Set the vector from an array of blocks and coefficients, without resizing.
92 : void SetBlocks(const std::vector<const ComplexVector *> &y,
93 : const std::vector<std::complex<double>> &s);
94 :
95 : // Scale all entries by s.
96 : ComplexVector &operator*=(std::complex<double> s);
97 :
98 : // Replace entries with their complex conjugate.
99 : void Conj();
100 :
101 : // Replace entries with their absolute value.
102 : void Abs();
103 :
104 : // Set all entries to their reciprocal.
105 : void Reciprocal();
106 :
107 : // Vector dot product (yᴴ x) or indefinite dot product (yᵀ x) for complex vectors.
108 : std::complex<double> Dot(const ComplexVector &y) const;
109 : std::complex<double> TransposeDot(const ComplexVector &y) const;
110 0 : std::complex<double> operator*(const ComplexVector &y) const { return Dot(y); }
111 :
112 : // In-place addition (*this) += alpha * x.
113 : void AXPY(std::complex<double> alpha, const ComplexVector &x);
114 0 : void Add(std::complex<double> alpha, const ComplexVector &x) { AXPY(alpha, x); }
115 : void Subtract(std::complex<double> alpha, const ComplexVector &x) { AXPY(-alpha, x); }
116 : ComplexVector &operator+=(const ComplexVector &x)
117 : {
118 0 : AXPY(1.0, x);
119 0 : return *this;
120 : }
121 : ComplexVector &operator-=(const ComplexVector &x)
122 : {
123 3 : AXPY(-1.0, x);
124 : return *this;
125 : }
126 :
127 : // In-place addition (*this) = alpha * x + beta * (*this).
128 : void AXPBY(std::complex<double> alpha, const ComplexVector &x, std::complex<double> beta);
129 :
130 : // In-place addition (*this) = alpha * x + beta * y + gamma * (*this).
131 : void AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
132 : std::complex<double> beta, const ComplexVector &y,
133 : std::complex<double> gamma);
134 :
135 : static void AXPY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
136 : Vector &yr, Vector &yi);
137 :
138 : static void AXPBY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
139 : std::complex<double> beta, Vector &yr, Vector &yi);
140 :
141 : static void AXPBYPCZ(std::complex<double> alpha, const Vector &xr, const Vector &xi,
142 : std::complex<double> beta, const Vector &yr, const Vector &yi,
143 : std::complex<double> gamma, Vector &zr, Vector &zi);
144 : };
145 :
146 : // A stack-allocated vector with compile-time fixed size.
147 : //
148 : // StaticVector provides a Vector interface backed by stack memory instead of
149 : // heap allocation. The size N is fixed at compile time, making it suitable for
150 : // small vectors where performance and avoiding dynamic allocation are
151 : // important.
152 : //
153 : // Template parameters:
154 : // - N: The fixed size of the vector (number of elements)
155 : //
156 : // Notes:
157 : // - Inherits from mfem::Vector, so can be used anywhere Vector is expected.
158 : // - Memory is automatically managed (no new/delete needed).
159 : // - Faster than dynamic Vector for small sizes due to stack allocation.
160 : //
161 : // Example usage:
162 : //
163 : // StaticVector<3> vec; // 3D vector on stack
164 : // vec[0] = 1.0;
165 : // vec[1] = 2.0;
166 : // vec[2] = 3.0;
167 : //
168 : // vec.Sum();
169 : //
170 : // You can also create StaticComplexVectors:
171 : //
172 : // StaticVector<3> vec_real, vec_imag;
173 : // ComplexVector complex_vec(vec_real, vec_imag);
174 : template <int N>
175 : class StaticVector : public Vector
176 : {
177 : private:
178 : double buff[N];
179 :
180 : public:
181 115214 : StaticVector() : Vector() { SetDataAndSize(buff, N); }
182 :
183 0 : ~StaticVector()
184 : {
185 : MFEM_ASSERT(GetData() == buff,
186 : "Buffer of StaticVector changed. This indicates a possible bug.");
187 : MFEM_ASSERT(Size() == N, "Size of StaticVector changed. This indicates a possible bug.")
188 9609 : }
189 :
190 : using Vector::operator=; // Extend the implicitly defined assignment operators
191 : };
192 :
193 : namespace linalg
194 : {
195 :
196 : // Returns the global vector size.
197 : template <typename VecType>
198 : inline HYPRE_BigInt GlobalSize(MPI_Comm comm, const VecType &x)
199 : {
200 : HYPRE_BigInt N = x.Size();
201 : Mpi::GlobalSum(1, &N, comm);
202 : return N;
203 : }
204 :
205 : // Returns the global vector size for two vectors.
206 : template <typename VecType1, typename VecType2>
207 0 : inline std::pair<HYPRE_BigInt, HYPRE_BigInt> GlobalSize2(MPI_Comm comm, const VecType1 &x1,
208 : const VecType2 &x2)
209 : {
210 0 : HYPRE_BigInt N[2] = {x1.Size(), x2.Size()};
211 : Mpi::GlobalSum(2, N, comm);
212 0 : return {N[0], N[1]};
213 : }
214 :
// Sets all entries of the vector at the given indices (rows) either to the given real value
// s, or to values taken from the vector y.
// NOTE(review): for the y overload, confirm in the implementation whether y is indexed like x
// (x[rows[i]] = y[rows[i]]) or compactly (x[rows[i]] = y[i]). For ComplexVector, confirm how a
// real s is applied to the imaginary part.
template <typename VecType>
void SetSubVector(VecType &x, const mfem::Array<int> &rows, double s);
template <typename VecType>
void SetSubVector(VecType &x, const mfem::Array<int> &rows, const VecType &y);

// Sets contiguous entries of x, beginning at index start, from the given vector y
// (presumably x[start + i] = y[i] for i in [0, y.Size()) — confirm against the implementation).
template <typename VecType>
void SetSubVector(VecType &x, int start, const VecType &y);

// Sets all entries in the half-open index range [start, end) to the given value s.
template <typename VecType>
void SetSubVector(VecType &x, int start, int end, double s);
229 :
// Sets all entries of the vector to random numbers sampled from [-1, 1], or from the complex
// square [-1 - 1i, 1 + 1i] for complex-valued vectors. The seed defaults to 0.
// NOTE(review): the role of comm (e.g. per-rank seeding) is not visible in this header;
// confirm in the implementation.
template <typename VecType>
void SetRandom(MPI_Comm comm, VecType &x, int seed = 0);
// Presumably like SetRandom but with real-valued samples only (zero imaginary part for
// complex vectors) — TODO confirm.
template <typename VecType>
void SetRandomReal(MPI_Comm comm, VecType &x, int seed = 0);
// Presumably fills x with random signs (±1 entries) — TODO confirm against the implementation.
template <typename VecType>
void SetRandomSign(MPI_Comm comm, VecType &x, int seed = 0);
238 :
// Calculate the inner product of the entries stored on the calling process only (no MPI
// reduction; Dot sums these local contributions across processes). For real vectors this is
// yᵀ x. NOTE(review): for ComplexVector, confirm whether this is the Hermitian yᴴ x (as
// ComplexVector::Dot is documented) or the unconjugated yᵀ x.
double LocalDot(const Vector &x, const Vector &y);
std::complex<double> LocalDot(const ComplexVector &x, const ComplexVector &y);
242 :
243 : // Calculate the parallel inner product yᴴ x or yᵀ x.
244 : template <typename VecType>
245 0 : inline auto Dot(MPI_Comm comm, const VecType &x, const VecType &y)
246 : {
247 0 : auto dot = LocalDot(x, y);
248 : Mpi::GlobalSum(1, &dot, comm);
249 0 : return dot;
250 : }
251 :
252 : // Calculate the vector 2-norm.
253 : template <typename VecType>
254 0 : inline auto Norml2(MPI_Comm comm, const VecType &x)
255 : {
256 0 : return std::sqrt(std::abs(Dot(comm, x, x)));
257 : }
258 :
259 : // Normalize the vector, possibly with respect to an SPD matrix B.
260 : template <typename VecType>
261 : inline auto Normalize(MPI_Comm comm, VecType &x)
262 : {
263 : auto norm = Norml2(comm, x);
264 : MFEM_ASSERT(norm > 0.0, "Zero vector norm in normalization!");
265 : x *= 1.0 / norm;
266 : return norm;
267 : }
268 :
// Calculate the sum of the entries stored on the calling process only (no MPI reduction;
// Sum and Mean reduce these local sums across processes). The ComplexVector overload returns
// a complex-valued sum.
double LocalSum(const Vector &x);
std::complex<double> LocalSum(const ComplexVector &x);
272 :
273 : // Calculate the sum of all elements in the vector.
274 : template <typename VecType>
275 8 : inline auto Sum(MPI_Comm comm, const VecType &x)
276 : {
277 8 : auto sum = LocalSum(x);
278 : Mpi::GlobalSum(1, &sum, comm);
279 8 : return sum;
280 : }
281 :
282 : // Calculate the mean of all elements in the vector.
283 : template <typename VecType>
284 0 : inline auto Mean(MPI_Comm comm, const VecType &x)
285 : {
286 : using ScalarType = typename std::conditional<std::is_same<VecType, ComplexVector>::value,
287 : std::complex<double>, double>::type;
288 0 : ScalarType sum[2] = {LocalSum(x), ScalarType(x.Size())};
289 : Mpi::GlobalSum(2, sum, comm);
290 0 : return sum[0] / sum[1];
291 : }
292 :
293 : // Normalize a complex vector so its mean is on the positive real axis.
294 : // Returns the original mean phase.
295 0 : inline double NormalizePhase(MPI_Comm comm, ComplexVector &x)
296 : {
297 0 : std::complex<double> mean = Mean(comm, x);
298 0 : x *= std::conj(mean) / std::abs(mean);
299 0 : return std::atan2(mean.imag(), mean.real());
300 : }
301 :
// Addition y += alpha * x.
// Note: alpha, beta and gamma in this family all share the single template parameter
// ScalarType, so they must be passed with the same type (e.g. all double or all
// std::complex<double>) for template argument deduction to succeed.
template <typename VecType, typename ScalarType>
void AXPY(ScalarType alpha, const VecType &x, VecType &y);

// Addition y = alpha * x + beta * y.
template <typename VecType, typename ScalarType>
void AXPBY(ScalarType alpha, const VecType &x, ScalarType beta, VecType &y);

// Addition z = alpha * x + beta * y + gamma * z.
template <typename VecType, typename ScalarType>
void AXPBYPCZ(ScalarType alpha, const VecType &x, ScalarType beta, const VecType &y,
              ScalarType gamma, VecType &z);
314 :
// Compute the element-wise square root of x in place, optionally with scaling: each entry
// x_i becomes sqrt(s * x_i), the factor s being applied before the square root. With the
// default s = 1.0 this is the plain element-wise square root.
// NOTE(review): handling of entries with s * x_i < 0 is not visible here; confirm in the
// implementation.
void Sqrt(Vector &x, double s = 1.0);
318 :
// Compute the 3D cross (vector) product of A and B and store it in C, i.e. C = A x B. If add
// is true, accumulate instead: C += A x B.
//
// Only entries 0, 1 and 2 of A, B and C are accessed, through operator[]. All three
// components of A x B are evaluated before C is written, so C may alias A and/or B
// (e.g. Cross3(a, b, a) overwrites a with a x b).
template <typename VecTypeA, typename VecTypeB, typename VecTypeC>
void Cross3(const VecTypeA &A, const VecTypeB &B, VecTypeC &C, bool add = false)
{
  // Evaluate first, write second: writing C[0] directly would corrupt A[0] or B[0] when C
  // aliases an input, and that entry is still needed for components 1 and 2.
  const auto c0 = A[1] * B[2] - A[2] * B[1];
  const auto c1 = A[2] * B[0] - A[0] * B[2];
  const auto c2 = A[0] * B[1] - A[1] * B[0];
  if (add)
  {
    C[0] += c0;
    C[1] += c1;
    C[2] += c2;
  }
  else
  {
    C[0] = c0;
    C[1] = c1;
    C[2] = c2;
  }
}
338 :
339 : } // namespace linalg
340 :
341 : } // namespace palace
342 :
343 : #endif // PALACE_LINALG_VECTOR_HPP
|