LCOV - code coverage report
Current view: top level - linalg - hypre.hpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 100.0 % 6 6
Test Date: 2025-10-23 22:45:05 Functions: - 0 0
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_LINALG_HYPRE_HPP
       5              : #define PALACE_LINALG_HYPRE_HPP
       6              : 
       7              : #include <mfem.hpp>
       8              : #include "linalg/operator.hpp"
       9              : #include "linalg/vector.hpp"
      10              : 
      11              : namespace palace::hypre
      12              : {
      13              : 
      14              : // Helper function to initialize HYPRE and control use of GPU at runtime. This will call
      15              : // HYPRE_SetMemoryLocation and HYPRE_SetExecutionPolicy to match the mfem::Device
      16              : // configuration.
      17              : inline void Initialize()
      18              : {
      19              :   mfem::Hypre::Init();
      20              :   // HYPRE_SetSpGemmUseCusparse(1);  // MFEM sets to zero, so leave as is for now
      21              : }
      22              : 
      23              : //
      24              : // Wrapper class for HYPRE's hypre_Vector, which can alias an mfem::Vector object for use
      25              : // with HYPRE.
      26              : //
      27              : class HypreVector
      28              : {
      29              : private:
      30              :   hypre_Vector *vec;
      31              : 
      32              : public:
      33              :   HypreVector(hypre_Vector *vec = nullptr);
      34              :   HypreVector(const Vector &x);
      35              :   ~HypreVector();
      36              : 
      37              :   auto Size() const { return hypre_VectorSize(vec); }
      38              : 
      39              :   void Update(const Vector &x);
      40              : 
      41           47 :   operator hypre_Vector *() const { return vec; }
      42              : };
      43              : 
      44              : //
      45              : // Wrapper class for HYPRE's hypre_CSRMatrix, an alternative to mfem::SparseMatrix with
      46              : // increased functionality from HYPRE.
      47              : //
      48              : class HypreCSRMatrix : public palace::Operator
      49              : {
      50              : private:
      51              :   hypre_CSRMatrix *mat;
      52              :   mfem::Array<HYPRE_Int> data_I, data_J;
      53              :   bool hypre_own_I;
      54              : 
      55              : public:
      56              :   HypreCSRMatrix(int h, int w, int nnz);
      57              :   HypreCSRMatrix(hypre_CSRMatrix *mat);
      58              :   HypreCSRMatrix(const mfem::SparseMatrix &m);
      59              :   ~HypreCSRMatrix();
      60              : 
      61          246 :   auto NNZ() const { return hypre_CSRMatrixNumNonzeros(mat); }
      62              : 
      63              :   const auto *GetI() const { return hypre_CSRMatrixI(mat); }
      64        23562 :   auto *GetI() { return hypre_CSRMatrixI(mat); }
      65              :   const auto *GetJ() const { return hypre_CSRMatrixJ(mat); }
      66        23562 :   auto *GetJ() { return hypre_CSRMatrixJ(mat); }
      67              :   const auto *GetData() const { return hypre_CSRMatrixData(mat); }
      68        23766 :   auto *GetData() { return hypre_CSRMatrixData(mat); }
      69              : 
      70              :   void AssembleDiagonal(Vector &diag) const override;
      71              : 
      72              :   void Mult(const Vector &x, Vector &y) const override;
      73              : 
      74              :   void AddMult(const Vector &x, Vector &y, const double a = 1.0) const override;
      75              : 
      76              :   void MultTranspose(const Vector &x, Vector &y) const override;
      77              : 
      78              :   void AddMultTranspose(const Vector &x, Vector &y, const double a = 1.0) const override;
      79              : 
      80        22770 :   operator hypre_CSRMatrix *() const { return mat; }
      81              : };
      82              : 
      83              : }  // namespace palace::hypre
      84              : 
      85              : #endif  // PALACE_LINALG_HYPRE_HPP
        

Generated by: LCOV version 2.0-1