Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "strumpack.hpp"
5 :
6 : #if defined(MFEM_USE_STRUMPACK)
7 :
8 : namespace palace
9 : {
10 :
11 : namespace
12 : {
13 :
14 : strumpack::CompressionType GetCompressionType(SparseCompression type)
15 : {
16 0 : switch (type)
17 : {
18 : case SparseCompression::HSS:
19 : return strumpack::CompressionType::HSS;
20 : case SparseCompression::BLR:
21 : return strumpack::CompressionType::BLR;
22 : case SparseCompression::HODLR:
23 : return strumpack::CompressionType::HODLR;
24 : case SparseCompression::ZFP:
25 : return strumpack::CompressionType::LOSSY;
26 : case SparseCompression::BLR_HODLR:
27 : return strumpack::CompressionType::BLR_HODLR;
28 : break;
29 : case SparseCompression::ZFP_BLR_HODLR:
30 : return strumpack::CompressionType::ZFP_BLR_HODLR;
31 : break;
32 : case SparseCompression::NONE:
33 : return strumpack::CompressionType::NONE;
34 : }
35 : return strumpack::CompressionType::NONE; // For compiler warning
36 : }
37 :
38 : } // namespace
39 :
40 : template <typename StrumpackSolverType>
41 0 : StrumpackSolverBase<StrumpackSolverType>::StrumpackSolverBase(
42 : MPI_Comm comm, SymbolicFactorization reorder, SparseCompression compression,
43 : double lr_tol, int butterfly_l, int lossy_prec, int print)
44 0 : : StrumpackSolverType(comm), comm(comm)
45 : {
46 : // Configure the solver.
47 0 : this->SetPrintFactorStatistics(print > 1);
48 0 : this->SetPrintSolveStatistics(print > 1);
49 0 : this->SetKrylovSolver(strumpack::KrylovSolver::DIRECT); // Always as a preconditioner or
50 : // direct solver
51 0 : this->SetMatching(strumpack::MatchingJob::NONE);
52 0 : switch (reorder)
53 : {
54 0 : case SymbolicFactorization::METIS:
55 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::METIS);
56 : // this->SetReorderingStrategy(strumpack::ReorderingStrategy::AND);
57 : break;
58 0 : case SymbolicFactorization::PARMETIS:
59 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::PARMETIS);
60 : break;
61 0 : case SymbolicFactorization::SCOTCH:
62 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::SCOTCH);
63 : break;
64 0 : case SymbolicFactorization::PTSCOTCH:
65 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::PTSCOTCH);
66 : break;
67 0 : case SymbolicFactorization::AMD:
68 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::AMD);
69 : // this->SetReorderingStrategy(strumpack::ReorderingStrategy::MMD);
70 : break;
71 0 : case SymbolicFactorization::RCM:
72 0 : this->SetReorderingStrategy(strumpack::ReorderingStrategy::RCM);
73 : case SymbolicFactorization::PORD:
74 : case SymbolicFactorization::DEFAULT:
75 : // Should have good default.
76 : break;
77 : }
78 0 : this->SetReorderingReuse(true); // Repeated calls use same sparsity pattern
79 :
80 : // Configure compression.
81 0 : this->SetCompression(GetCompressionType(compression));
82 0 : switch (compression)
83 : {
84 0 : case SparseCompression::ZFP:
85 0 : if (lossy_prec <= 0)
86 : {
87 0 : this->SetCompression(strumpack::CompressionType::LOSSLESS);
88 : }
89 : else
90 : {
91 0 : this->SetCompressionLossyPrecision(lossy_prec);
92 : }
93 : break;
94 0 : case SparseCompression::ZFP_BLR_HODLR:
95 0 : this->SetCompressionLossyPrecision(lossy_prec);
96 0 : case SparseCompression::HODLR:
97 : case SparseCompression::BLR_HODLR:
98 0 : this->SetCompressionButterflyLevels(butterfly_l);
99 0 : case SparseCompression::HSS:
100 : case SparseCompression::BLR:
101 0 : this->SetCompressionRelTol(lr_tol);
102 : break;
103 : case SparseCompression::NONE:
104 : break;
105 : }
106 : // if (mfem::Device::Allows(mfem::Backend::DEVICE_MASK))
107 : // {
108 : // this->EnableGPU(); // XX TODO: GPU support disabled for now
109 : // }
110 : // else
111 : {
112 0 : this->DisableGPU();
113 : }
114 0 : }
115 :
116 : template <typename StrumpackSolverType>
117 0 : void StrumpackSolverBase<StrumpackSolverType>::SetOperator(const Operator &op)
118 : {
119 : // Convert the input operator to a distributed STRUMPACK matrix (always assume a symmetric
120 : // sparsity pattern). This is very similar to the MFEM's STRUMPACKRowLocMatrix from a
121 : // HypreParMatrix but avoids using the communicator from the Hypre matrix in the case that
122 : // the solver is constructed on a different communicator.
123 0 : const auto *hA = dynamic_cast<const mfem::HypreParMatrix *>(&op);
124 0 : MFEM_VERIFY(hA && hA->GetGlobalNumRows() == hA->GetGlobalNumCols(),
125 : "StrumpackSolver requires a square HypreParMatrix operator!");
126 : auto *parcsr = (hypre_ParCSRMatrix *)const_cast<mfem::HypreParMatrix &>(*hA);
127 0 : hypre_CSRMatrix *csr = hypre_MergeDiagAndOffd(parcsr);
128 0 : hypre_CSRMatrixMigrate(csr, HYPRE_MEMORY_HOST);
129 :
130 : // Create the STRUMPACKRowLocMatrix by taking the internal data from a hypre_CSRMatrix.
131 0 : HYPRE_BigInt glob_n = hypre_ParCSRMatrixGlobalNumRows(parcsr);
132 0 : HYPRE_BigInt first_row = hypre_ParCSRMatrixFirstRowIndex(parcsr);
133 0 : HYPRE_Int n_loc = hypre_CSRMatrixNumRows(csr);
134 0 : HYPRE_Int *I = hypre_CSRMatrixI(csr);
135 0 : HYPRE_BigInt *J = hypre_CSRMatrixBigJ(csr);
136 0 : double *data = hypre_CSRMatrixData(csr);
137 :
138 : // Safe to delete the matrix since STRUMPACK copies it on input. Also clean up the Hypre
139 : // data structure once we are done with it.
140 : #if !defined(HYPRE_BIGINT)
141 0 : mfem::STRUMPACKRowLocMatrix A(comm, n_loc, first_row, glob_n, glob_n, I, J, data, true);
142 : #else
143 : int n_loc_int = static_cast<int>(n_loc);
144 : MFEM_ASSERT(n_loc == (HYPRE_Int)n_loc_int,
145 : "Overflow error for local sparse matrix size!");
146 : mfem::Array<int> II(n_loc_int + 1);
147 : for (int i = 0; i <= n_loc_int; i++)
148 : {
149 : II[i] = static_cast<int>(I[i]);
150 : MFEM_ASSERT(I[i] == (HYPRE_Int)II[i], "Overflow error for local sparse matrix index!");
151 : }
152 : mfem::STRUMPACKRowLocMatrix A(comm, n_loc_int, first_row, glob_n, glob_n, II.HostRead(),
153 : J, data, true);
154 : #endif
155 0 : StrumpackSolverType::SetOperator(A);
156 0 : hypre_CSRMatrixDestroy(csr);
157 0 : }
158 :
// Explicit instantiations of StrumpackSolverBase for the two MFEM STRUMPACK solver types.
159 : template class StrumpackSolverBase<mfem::STRUMPACKSolver>;
160 : template class StrumpackSolverBase<mfem::STRUMPACKMixedPrecisionSolver>;
161 :
162 : } // namespace palace
163 :
164 : #endif
|