Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_LINALG_HYPRE_HPP
5 : #define PALACE_LINALG_HYPRE_HPP
6 :
7 : #include <mfem.hpp>
8 : #include "linalg/operator.hpp"
9 : #include "linalg/vector.hpp"
10 :
11 : namespace palace::hypre
12 : {
13 :
14 : // Helper function to initialize HYPRE and control use of GPU at runtime. This will call
15 : // HYPRE_SetMemoryLocation and HYPRE_SetExecutionPolicy to match the mfem::Device
16 : // configuration.
17 : inline void Initialize()
18 : {
19 : mfem::Hypre::Init();
20 : // HYPRE_SetSpGemmUseCusparse(1); // MFEM sets to zero, so leave as is for now
21 : }
22 :
23 : //
24 : // Wrapper class for HYPRE's hypre_Vector, which can alias an mfem::Vector object for use
25 : // with HYPRE.
26 : //
27 : class HypreVector
28 : {
29 : private:
30 : hypre_Vector *vec;
31 :
32 : public:
33 : HypreVector(hypre_Vector *vec = nullptr);
34 : HypreVector(const Vector &x);
35 : ~HypreVector();
36 :
37 : auto Size() const { return hypre_VectorSize(vec); }
38 :
39 : void Update(const Vector &x);
40 :
41 47 : operator hypre_Vector *() const { return vec; }
42 : };
43 :
44 : //
45 : // Wrapper class for HYPRE's hypre_CSRMatrix, an alternative to mfem::SparseMatrix with
46 : // increased functionality from HYPRE.
47 : //
48 : class HypreCSRMatrix : public palace::Operator
49 : {
50 : private:
51 : hypre_CSRMatrix *mat;
52 : mfem::Array<HYPRE_Int> data_I, data_J;
53 : bool hypre_own_I;
54 :
55 : public:
56 : HypreCSRMatrix(int h, int w, int nnz);
57 : HypreCSRMatrix(hypre_CSRMatrix *mat);
58 : HypreCSRMatrix(const mfem::SparseMatrix &m);
59 : ~HypreCSRMatrix();
60 :
61 246 : auto NNZ() const { return hypre_CSRMatrixNumNonzeros(mat); }
62 :
63 : const auto *GetI() const { return hypre_CSRMatrixI(mat); }
64 23562 : auto *GetI() { return hypre_CSRMatrixI(mat); }
65 : const auto *GetJ() const { return hypre_CSRMatrixJ(mat); }
66 23562 : auto *GetJ() { return hypre_CSRMatrixJ(mat); }
67 : const auto *GetData() const { return hypre_CSRMatrixData(mat); }
68 23766 : auto *GetData() { return hypre_CSRMatrixData(mat); }
69 :
70 : void AssembleDiagonal(Vector &diag) const override;
71 :
72 : void Mult(const Vector &x, Vector &y) const override;
73 :
74 : void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
75 :
76 : void MultTranspose(const Vector &x, Vector &y) const override;
77 :
78 : void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
79 :
80 22770 : operator hypre_CSRMatrix *() const { return mat; }
81 : };
82 :
83 : } // namespace palace::hypre
84 :
85 : #endif // PALACE_LINALG_HYPRE_HPP
|