Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "hypre.hpp"
5 :
6 : namespace palace::hypre
7 : {
8 :
9 146 : HypreVector::HypreVector(hypre_Vector *vec) : vec(vec) {}
10 :
11 0 : HypreVector::HypreVector(const Vector &x) : vec(nullptr)
12 : {
13 0 : Update(x);
14 0 : }
15 :
16 146 : HypreVector::~HypreVector()
17 : {
18 146 : hypre_SeqVectorDestroy(vec);
19 146 : }
20 :
21 83 : void HypreVector::Update(const Vector &x)
22 : {
23 : const HYPRE_Int N = x.Size();
24 83 : if (!vec)
25 : {
26 14 : vec = hypre_SeqVectorCreate(N);
27 14 : hypre_SeqVectorSetDataOwner(vec, 0);
28 14 : hypre_VectorData(vec) = const_cast<double *>(x.Read());
29 14 : hypre_SeqVectorInitialize(vec);
30 : }
31 : else
32 : {
33 69 : hypre_SeqVectorSetSize(vec, N);
34 69 : hypre_VectorData(vec) = const_cast<double *>(x.Read());
35 : }
36 83 : }
37 :
38 21900 : HypreCSRMatrix::HypreCSRMatrix(int h, int w, int nnz)
39 21900 : : palace::Operator(h, w), hypre_own_I(true)
40 : {
41 21900 : mat = hypre_CSRMatrixCreate(h, w, nnz);
42 21900 : hypre_CSRMatrixInitialize(mat);
43 21900 : }
44 :
45 11394 : HypreCSRMatrix::HypreCSRMatrix(hypre_CSRMatrix *mat) : mat(mat), hypre_own_I(true)
46 : {
47 11394 : height = hypre_CSRMatrixNumRows(mat);
48 11394 : width = hypre_CSRMatrixNumCols(mat);
49 11394 : }
50 :
51 0 : HypreCSRMatrix::HypreCSRMatrix(const mfem::SparseMatrix &m)
52 0 : : palace::Operator(m.Height(), m.Width()), hypre_own_I(false)
53 : {
54 0 : const int nnz = m.NumNonZeroElems();
55 0 : mat = hypre_CSRMatrixCreate(height, width, nnz);
56 0 : hypre_CSRMatrixSetDataOwner(mat, 0);
57 0 : hypre_CSRMatrixData(mat) = const_cast<double *>(m.ReadData());
58 : #if !defined(HYPRE_BIGINT)
59 0 : hypre_CSRMatrixI(mat) = const_cast<int *>(m.ReadI());
60 0 : hypre_CSRMatrixJ(mat) = const_cast<int *>(m.ReadJ());
61 : #else
62 : data_I.SetSize(height);
63 : data_J.SetSize(nnz);
64 : {
65 : const auto *I = m.ReadI();
66 : const auto *J = m.ReadJ();
67 : auto *DI = data_I.Write();
68 : auto *DJ = data_J.Write();
69 : mfem::forall(height, [=] MFEM_HOST_DEVICE(int i) { DI[i] = I[i]; });
70 : mfem::forall(nnz, [=] MFEM_HOST_DEVICE(int i) { DJ[i] = J[i]; });
71 : }
72 : #endif
73 0 : hypre_CSRMatrixInitialize(mat);
74 0 : }
75 :
76 66588 : HypreCSRMatrix::~HypreCSRMatrix()
77 : {
78 33294 : if (!hypre_own_I)
79 : {
80 0 : hypre_CSRMatrixI(mat) = nullptr;
81 : }
82 33294 : hypre_CSRMatrixDestroy(mat);
83 66588 : }
84 :
85 0 : void HypreCSRMatrix::AssembleDiagonal(Vector &diag) const
86 : {
87 0 : diag.SetSize(height);
88 0 : hypre_CSRMatrixExtractDiagonal(mat, diag.Write(), 0);
89 0 : }
90 :
91 : namespace
92 : {
93 :
94 : static HypreVector X, Y;
95 :
96 : } // namespace
97 :
98 0 : void HypreCSRMatrix::Mult(const Vector &x, Vector &y) const
99 : {
100 0 : X.Update(x);
101 0 : Y.Update(y);
102 0 : hypre_CSRMatrixMatvec(1.0, mat, X, 0.0, Y);
103 0 : }
104 :
105 0 : void HypreCSRMatrix::AddMult(const Vector &x, Vector &y, const double a) const
106 : {
107 0 : X.Update(x);
108 0 : Y.Update(y);
109 0 : hypre_CSRMatrixMatvec(a, mat, X, 1.0, Y);
110 0 : }
111 :
112 0 : void HypreCSRMatrix::MultTranspose(const Vector &x, Vector &y) const
113 : {
114 0 : X.Update(x);
115 0 : Y.Update(y);
116 0 : hypre_CSRMatrixMatvecT(1.0, mat, X, 0.0, Y);
117 0 : }
118 :
119 0 : void HypreCSRMatrix::AddMultTranspose(const Vector &x, Vector &y, const double a) const
120 : {
121 0 : X.Update(x);
122 0 : Y.Update(y);
123 0 : hypre_CSRMatrixMatvecT(a, mat, X, 1.0, Y);
124 0 : }
125 :
126 : } // namespace palace::hypre
|