Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_ORTHOG_HPP
5 : #define PALACE_LINALG_ORTHOG_HPP
6 :
7 : #include <vector>
8 : #include "linalg/vector.hpp"
9 : #include "utils/communication.hpp"
10 :
11 : namespace palace::linalg
12 : {
13 :
14 : //
15 : // Orthogonalization functions for orthogonalizing a vector against a number of basis
16 : // vectors using modified or classical Gram-Schmidt.
17 : //
18 :
19 : template <typename VecType, typename ScalarType>
20 0 : inline void OrthogonalizeColumnMGS(MPI_Comm comm, const std::vector<VecType> &V, VecType &w,
21 : ScalarType *H, int m)
22 : {
23 : MFEM_ASSERT(static_cast<std::size_t>(m) <= V.size(),
24 : "Out of bounds number of columns for MGS orthogonalization!");
25 0 : for (int j = 0; j < m; j++)
26 : {
27 0 : H[j] = linalg::Dot(comm, w, V[j]); // Global inner product
28 0 : w.Add(-H[j], V[j]);
29 : }
30 0 : }
31 :
32 : template <typename VecType, typename ScalarType>
33 0 : inline void OrthogonalizeColumnCGS(MPI_Comm comm, const std::vector<VecType> &V, VecType &w,
34 : ScalarType *H, int m, bool refine = false)
35 : {
36 : MFEM_ASSERT(static_cast<std::size_t>(m) <= V.size(),
37 : "Out of bounds number of columns for CGS orthogonalization!");
38 0 : if (m == 0)
39 : {
40 : return;
41 : }
42 0 : for (int j = 0; j < m; j++)
43 : {
44 0 : H[j] = w * V[j]; // Local inner product
45 : }
46 : Mpi::GlobalSum(m, H, comm);
47 0 : for (int j = 0; j < m; j++)
48 : {
49 0 : w.Add(-H[j], V[j]);
50 : }
51 0 : if (refine)
52 : {
53 0 : std::vector<ScalarType> dH(m);
54 0 : for (int j = 0; j < m; j++)
55 : {
56 0 : dH[j] = w * V[j]; // Local inner product
57 : }
58 : Mpi::GlobalSum(m, dH.data(), comm);
59 0 : for (int j = 0; j < m; j++)
60 : {
61 0 : H[j] += dH[j];
62 0 : w.Add(-dH[j], V[j]);
63 : }
64 : }
65 : }
66 :
67 : } // namespace palace::linalg
68 :
69 : #endif // PALACE_LINALG_ORTHOG_HPP
|