LCOV - code coverage report
Current view: top level - fem/libceed - restriction.cpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 53.2 % 94 50
Test Date: 2025-10-23 22:45:05 Functions: 66.7 % 6 4
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #include "restriction.hpp"
       5              : 
       6              : #include <mfem.hpp>
       7              : #include "utils/omp.hpp"
       8              : 
       9              : namespace palace::ceed
      10              : {
      11              : 
      12              : namespace
      13              : {
      14              : 
      15            0 : const mfem::FiniteElement *GetTraceElement(const mfem::FiniteElementSpace &fespace,
      16              :                                            const std::vector<int> &indices)
      17              : {
      18              :   int elem_id, face_info;
      19            0 :   fespace.GetMesh()->GetBdrElementAdjacentElement(indices[0], elem_id, face_info);
      20            0 :   mfem::Geometry::Type face_geom = fespace.GetMesh()->GetBdrElementGeometry(indices[0]);
      21            0 :   return fespace.GetTraceElement(elem_id, face_geom);
      22              : };
      23              : 
      24            0 : mfem::Array<int> GetFaceDofsFromAdjacentElement(const mfem::FiniteElementSpace &fespace,
      25              :                                                 mfem::DofTransformation &dof_trans,
      26              :                                                 const int P, const int e)
      27              : {
      28              :   // Get coordinates of face dofs.
      29              :   int elem_id, face_info;
      30            0 :   fespace.GetMesh()->GetBdrElementAdjacentElement(e, elem_id, face_info);
      31              :   mfem::Geometry::Type face_geom = fespace.GetMesh()->GetBdrElementGeometry(e);
      32            0 :   face_info = fespace.GetMesh()->EncodeFaceInfo(
      33              :       fespace.GetMesh()->DecodeFaceInfoLocalIndex(face_info),
      34              :       mfem::Geometry::GetInverseOrientation(
      35              :           face_geom, fespace.GetMesh()->DecodeFaceInfoOrientation(face_info)));
      36              :   mfem::IntegrationPointTransformation Loc1;
      37            0 :   fespace.GetMesh()->GetLocalFaceTransformation(fespace.GetMesh()->GetBdrElementType(e),
      38            0 :                                                 fespace.GetMesh()->GetElementType(elem_id),
      39              :                                                 Loc1.Transf, face_info);
      40            0 :   const mfem::FiniteElement *face_el = fespace.GetTraceElement(elem_id, face_geom);
      41            0 :   MFEM_VERIFY(dynamic_cast<const mfem::NodalFiniteElement *>(face_el),
      42              :               "Mesh requires nodal Finite Element.");
      43            0 :   mfem::IntegrationRule face_ir(face_el->GetDof());
      44            0 :   Loc1.Transf.ElementNo = elem_id;
      45            0 :   Loc1.Transf.mesh = fespace.GetMesh();
      46            0 :   Loc1.Transf.ElementType = mfem::ElementTransformation::ELEMENT;
      47            0 :   Loc1.Transform(face_el->GetNodes(), face_ir);
      48            0 :   mfem::DenseMatrix face_pm;
      49            0 :   fespace.GetMesh()->GetNodes()->GetVectorValues(Loc1.Transf, face_ir, face_pm);
      50              : 
      51              :   // Get coordinates of element dofs.
      52            0 :   mfem::DenseMatrix elem_pm;
      53            0 :   const mfem::FiniteElement *fe_elem = fespace.GetFE(elem_id);
      54            0 :   mfem::IsoparametricTransformation T;
      55            0 :   fespace.GetMesh()->GetElementTransformation(elem_id, &T);
      56            0 :   T.Transform(fe_elem->GetNodes(), elem_pm);
      57              : 
      58              :   // Find the dofs.
      59              :   double tol = 1E-5;
      60              :   mfem::Array<int> elem_dofs, dofs(P);
      61            0 :   fespace.GetElementDofs(elem_id, elem_dofs, dof_trans);
      62            0 :   for (int l = 0; l < P; l++)
      63              :   {
      64            0 :     double norm2_f = 0.0;
      65            0 :     for (int m = 0; m < face_pm.Height(); m++)
      66              :     {
      67            0 :       norm2_f += face_pm(m, l) * face_pm(m, l);
      68              :     }
      69              : 
      70              :     bool found_match = false;
      71              :     MFEM_CONTRACT_VAR(found_match);  // silence unused warning
      72            0 :     for (int m = 0; m < elem_pm.Width(); m++)
      73              :     {
      74            0 :       double norm2_e = 0.0;
      75            0 :       for (int n = 0; n < elem_pm.Height(); n++)
      76              :       {
      77            0 :         norm2_e += elem_pm(n, m) * elem_pm(n, m);
      78              :       }
      79            0 :       double relative_tol = tol * std::max(std::max(norm2_f, norm2_e), 1.0E-6);
      80              :       double diff = 0.0;
      81            0 :       for (int o = 0; o < elem_pm.Height(); o++)
      82              :       {
      83            0 :         diff += std::fabs(elem_pm(o, m) - face_pm(o, l));
      84              :       }
      85            0 :       if (diff <= relative_tol)
      86              :       {
      87            0 :         dofs[l] = elem_dofs[m];
      88              :         found_match = true;
      89            0 :         break;
      90              :       }
      91              :     }
      92              : 
      93              :     MFEM_ASSERT(found_match,
      94              :                 [&]()
      95              :                 {
      96              :                   std::stringstream msg;
      97              :                   msg << "l " << l << '\n';
      98              :                   msg << "elem_dofs\n";
      99              :                   for (auto x : elem_dofs)
     100              :                     msg << x << ' ';
     101              : 
     102              :                   msg << "\ndofs\n";
     103              :                   for (auto x : dofs)
     104              :                     msg << x << ' ';
     105              :                   msg << '\n';
     106              :                   return msg.str();
     107              :                 }());
     108              :   }
     109              : 
     110            0 :   return dofs;
     111            0 : };
     112              : 
     113        36650 : void InitLexicoRestr(const mfem::FiniteElementSpace &fespace,
     114              :                      const std::vector<int> &indices, bool use_bdr, Ceed ceed,
     115              :                      CeedElemRestriction *restr)
     116              : {
     117              :   const std::size_t num_elem = indices.size();
     118              :   const mfem::FiniteElement *fe;
     119              :   bool face_flg = false;
     120        36650 :   if (!use_bdr)
     121              :   {
     122        16050 :     fe = fespace.GetFE(indices[0]);
     123              :   }
     124              :   else
     125              :   {
     126        20600 :     fe = fespace.GetBE(indices[0]);
     127        20600 :     if (!fe)
     128              :     {
     129            0 :       fe = GetTraceElement(fespace, indices);
     130              :       face_flg = true;
     131              :     }
     132              :   }
     133              :   const int P = fe->GetDof();
     134        36650 :   const mfem::TensorBasisElement *tfe = dynamic_cast<const mfem::TensorBasisElement *>(fe);
     135              :   const mfem::Array<int> &dof_map = tfe->GetDofMap();
     136        36650 :   const bool dof_map_is_identity = dof_map.Size() == 0;
     137              :   const CeedInt comp_stride =
     138        33259 :       (fespace.GetVDim() == 1 || fespace.GetOrdering() == mfem::Ordering::byVDIM)
     139        36650 :           ? 1
     140              :           : fespace.GetNDofs();
     141              :   const int stride =
     142        36650 :       (fespace.GetOrdering() == mfem::Ordering::byVDIM) ? fespace.GetVDim() : 1;
     143        36650 :   mfem::Array<int> tp_el_dof(num_elem * P);
     144              :   mfem::Array<bool> tp_el_orients(num_elem * P);
     145              :   int use_el_orients = 0;
     146              : 
     147        36650 :   PalacePragmaOmp(parallel reduction(+ : use_el_orients))
     148              :   {
     149              :     mfem::Array<int> dofs;
     150              :     mfem::DofTransformation dof_trans;
     151              :     bool use_el_orients_loc = false;
     152              : 
     153              :     PalacePragmaOmp(for schedule(static))
     154              :     for (std::size_t i = 0; i < num_elem; i++)
     155              :     {
     156              :       // No need to handle DofTransformation for tensor-product elements.
     157              :       const int e = indices[i];
     158              :       if (use_bdr)
     159              :       {
     160              :         if (!face_flg)
     161              :         {
     162              :           fespace.GetBdrElementDofs(e, dofs, dof_trans);
     163              :         }
     164              :         else
     165              :         {
     166              :           dofs = GetFaceDofsFromAdjacentElement(fespace, dof_trans, P, e);
     167              :         }
     168              :       }
     169              :       else
     170              :       {
     171              :         fespace.GetElementDofs(e, dofs, dof_trans);
     172              :       }
     173              :       MFEM_VERIFY(!dof_trans.GetDofTransformation(),
     174              :                   "Unexpected DofTransformation for lexicographic element "
     175              :                   "restriction.");
     176              :       for (int j = 0; j < P; j++)
     177              :       {
     178              :         const int sdid = dof_map_is_identity ? j : dof_map[j];  // signed
     179              :         const int did = (sdid >= 0) ? sdid : -1 - sdid;
     180              :         const int sgid = dofs[did];  // signed
     181              :         const int gid = (sgid >= 0) ? sgid : -1 - sgid;
     182              :         tp_el_dof[j + P * i] = stride * gid;
     183              :         tp_el_orients[j + P * i] = (sgid >= 0 && sdid < 0) || (sgid < 0 && sdid >= 0);
     184              :         use_el_orients_loc = use_el_orients_loc || tp_el_orients[j + P * i];
     185              :       }
     186              :     }
     187              :     use_el_orients += use_el_orients_loc;
     188              :   }
     189              : 
     190        36650 :   if (use_el_orients)
     191              :   {
     192            0 :     PalaceCeedCall(ceed, CeedElemRestrictionCreateOriented(
     193              :                              ceed, num_elem, P, fespace.GetVDim(), comp_stride,
     194              :                              fespace.GetVDim() * fespace.GetNDofs(), CEED_MEM_HOST,
     195              :                              CEED_COPY_VALUES, tp_el_dof.GetData(), tp_el_orients.GetData(),
     196              :                              restr));
     197              :   }
     198              :   else
     199              :   {
     200        36650 :     PalaceCeedCall(ceed, CeedElemRestrictionCreate(
     201              :                              ceed, num_elem, P, fespace.GetVDim(), comp_stride,
     202              :                              fespace.GetVDim() * fespace.GetNDofs(), CEED_MEM_HOST,
     203              :                              CEED_COPY_VALUES, tp_el_dof.GetData(), restr));
     204              :   }
     205        36650 : }
     206              : 
     207        42023 : void InitNativeRestr(const mfem::FiniteElementSpace &fespace,
     208              :                      const std::vector<int> &indices, bool use_bdr, bool is_interp_range,
     209              :                      Ceed ceed, CeedElemRestriction *restr)
     210              : {
     211              :   const std::size_t num_elem = indices.size();
     212              :   const mfem::FiniteElement *fe;
     213              :   bool face_flg = false;
     214        42023 :   if (!use_bdr)
     215              :   {
     216        31628 :     fe = fespace.GetFE(indices[0]);
     217              :   }
     218              :   else
     219              :   {
     220        10395 :     fe = fespace.GetBE(indices[0]);
     221        10395 :     if (!fe)
     222              :     {
     223            0 :       fe = GetTraceElement(fespace, indices);
     224              :       face_flg = true;
     225              :     }
     226              :   }
     227        42023 :   const int P = fe->GetDof();
     228              :   const CeedInt comp_stride =
     229        23535 :       (fespace.GetVDim() == 1 || fespace.GetOrdering() == mfem::Ordering::byVDIM)
     230        42023 :           ? 1
     231              :           : fespace.GetNDofs();
     232              :   const int stride =
     233        42023 :       (fespace.GetOrdering() == mfem::Ordering::byVDIM) ? fespace.GetVDim() : 1;
     234       126069 :   const bool has_dof_trans = [&]()
     235              :   {
     236        42023 :     if (fespace.GetMesh()->Dimension() < 3)
     237              :     {
     238              :       return false;
     239              :     }
     240        27104 :     const auto geom = fe->GetGeomType();
     241        27104 :     const auto *dof_trans = fespace.FEColl()->DofTransformationForGeometry(geom);
     242        27104 :     return (dof_trans && !dof_trans->IsIdentity());
     243        42023 :   }();
     244        42023 :   mfem::Array<int> tp_el_dof(num_elem * P);
     245              :   mfem::Array<bool> tp_el_orients;
     246              :   mfem::Array<int8_t> tp_el_curl_orients;
     247        42023 :   if (!has_dof_trans)
     248              :   {
     249              :     tp_el_orients.SetSize(num_elem * P);
     250              :   }
     251              :   else
     252              :   {
     253         2222 :     tp_el_curl_orients.SetSize(num_elem * P * 3, 0);
     254              :   }
     255              :   int use_el_orients = 0;
     256              : 
     257        42023 :   PalacePragmaOmp(parallel reduction(+ : use_el_orients))
     258              :   {
     259              :     mfem::Array<int> dofs;
     260              :     mfem::DofTransformation dof_trans;
     261              :     mfem::Vector el_trans_j;
     262              :     if (has_dof_trans)
     263              :     {
     264              :       el_trans_j.SetSize(P);
     265              :       el_trans_j = 0.0;
     266              :     }
     267              :     bool use_el_orients_loc = false;
     268              : 
     269              :     PalacePragmaOmp(for schedule(static))
     270              :     for (std::size_t i = 0; i < num_elem; i++)
     271              :     {
     272              :       const auto e = indices[i];
     273              :       if (use_bdr)
     274              :       {
     275              :         if (!face_flg)
     276              :         {
     277              :           fespace.GetBdrElementDofs(e, dofs, dof_trans);
     278              :         }
     279              :         else
     280              :         {
     281              :           dofs = GetFaceDofsFromAdjacentElement(fespace, dof_trans, P, e);
     282              :         }
     283              :       }
     284              :       else
     285              :       {
     286              :         fespace.GetElementDofs(e, dofs, dof_trans);
     287              :       }
     288              :       if (!has_dof_trans)
     289              :       {
     290              :         for (int j = 0; j < P; j++)
     291              :         {
     292              :           const int sgid = dofs[j];  // signed
     293              :           const int gid = (sgid >= 0) ? sgid : -1 - sgid;
     294              :           tp_el_dof[j + P * i] = stride * gid;
     295              :           tp_el_orients[j + P * i] = (sgid < 0);
     296              :           use_el_orients_loc = use_el_orients_loc || tp_el_orients[j + P * i];
     297              :         }
     298              :       }
     299              :       else
     300              :       {
     301              :         for (int j = 0; j < P; j++)
     302              :         {
     303              :           const int sgid = dofs[j];  // signed
     304              :           const int gid = (sgid >= 0) ? sgid : -1 - sgid;
     305              :           tp_el_dof[j + P * i] = stride * gid;
     306              : 
     307              :           // Fill column j of element tridiagonal matrix tp_el_curl_orients.
     308              :           el_trans_j(j) = 1.0;
     309              :           if (is_interp_range)
     310              :           {
     311              :             dof_trans.InvTransformDual(el_trans_j);
     312              :           }
     313              :           else
     314              :           {
     315              :             dof_trans.InvTransformPrimal(el_trans_j);
     316              :           }
     317              :           double sign_j = (sgid < 0) ? -1.0 : 1.0;
     318              :           tp_el_curl_orients[3 * (j + 0 + P * i) + 1] =
     319              :               static_cast<int8_t>(sign_j * el_trans_j(j));
     320              :           if (j > 0)
     321              :           {
     322              :             tp_el_curl_orients[3 * (j - 1 + P * i) + 2] =
     323              :                 static_cast<int8_t>(sign_j * el_trans_j(j - 1));
     324              :           }
     325              :           if (j < P - 1)
     326              :           {
     327              :             tp_el_curl_orients[3 * (j + 1 + P * i) + 0] =
     328              :                 static_cast<int8_t>(sign_j * el_trans_j(j + 1));
     329              :           }
     330              : 
     331              : #if defined(MFEM_DEBUG)
     332              :           // Debug check that transformation is actually tridiagonal.
     333              :           int nnz = 0;
     334              :           for (int k = 0; k < P; k++)
     335              :           {
     336              :             if ((k < j - 1 || k > j + 1) && el_trans_j(k) != 0.0)
     337              :             {
     338              :               nnz++;
     339              :             }
     340              :           }
     341              :           MFEM_ASSERT(nnz == 0,
     342              :                       "Element transformation matrix is not tridiagonal at column "
     343              :                           << j << " (nnz = " << nnz << ")!");
     344              : #endif
     345              : 
     346              :           // Zero out column vector for next iteration.
     347              :           el_trans_j(j) = 0.0;
     348              :           if (j > 0)
     349              :           {
     350              :             el_trans_j(j - 1) = 0.0;
     351              :           }
     352              :           if (j < P - 1)
     353              :           {
     354              :             el_trans_j(j + 1) = 0.0;
     355              :           }
     356              :         }
     357              :       }
     358              :     }
     359              :     use_el_orients += use_el_orients_loc;
     360              :   }
     361              : 
     362        42023 :   if (has_dof_trans)
     363              :   {
     364         2222 :     PalaceCeedCall(ceed, CeedElemRestrictionCreateCurlOriented(
     365              :                              ceed, num_elem, P, fespace.GetVDim(), comp_stride,
     366              :                              fespace.GetVDim() * fespace.GetNDofs(), CEED_MEM_HOST,
     367              :                              CEED_COPY_VALUES, tp_el_dof.GetData(),
     368              :                              tp_el_curl_orients.GetData(), restr));
     369              :   }
     370        39801 :   else if (use_el_orients)
     371              :   {
     372        11449 :     PalaceCeedCall(ceed, CeedElemRestrictionCreateOriented(
     373              :                              ceed, num_elem, P, fespace.GetVDim(), comp_stride,
     374              :                              fespace.GetVDim() * fespace.GetNDofs(), CEED_MEM_HOST,
     375              :                              CEED_COPY_VALUES, tp_el_dof.GetData(), tp_el_orients.GetData(),
     376              :                              restr));
     377              :   }
     378              :   else
     379              :   {
     380        28352 :     PalaceCeedCall(ceed, CeedElemRestrictionCreate(
     381              :                              ceed, num_elem, P, fespace.GetVDim(), comp_stride,
     382              :                              fespace.GetVDim() * fespace.GetNDofs(), CEED_MEM_HOST,
     383              :                              CEED_COPY_VALUES, tp_el_dof.GetData(), restr));
     384              :   }
     385        42023 : }
     386              : 
     387              : }  // namespace
     388              : 
     389        78673 : void InitRestriction(const mfem::FiniteElementSpace &fespace,
     390              :                      const std::vector<int> &indices, bool use_bdr, bool is_interp,
     391              :                      bool is_interp_range, Ceed ceed, CeedElemRestriction *restr)
     392              : {
     393              :   MFEM_ASSERT(!indices.empty(), "Empty element index set for libCEED element restriction!");
     394              :   if constexpr (false)
     395              :   {
     396              :     std::cout << "New element restriction (" << ceed << ", " << &fespace << ", "
     397              :               << indices[0] << ", " << use_bdr << ", " << is_interp << ", "
     398              :               << is_interp_range << ")\n";
     399              :   }
     400              :   const mfem::FiniteElement *fe;
     401        78673 :   if (!use_bdr)
     402              :   {
     403        47678 :     fe = fespace.GetFE(indices[0]);
     404              :   }
     405              :   else
     406              :   {
     407        30995 :     fe = fespace.GetBE(indices[0]);
     408        30995 :     if (!fe)
     409              :     {
     410            0 :       fe = GetTraceElement(fespace, indices);
     411              :     }
     412              :   }
     413        78673 :   const mfem::TensorBasisElement *tfe = dynamic_cast<const mfem::TensorBasisElement *>(fe);
     414              :   const bool vector = fe->GetRangeType() == mfem::FiniteElement::VECTOR;
     415        78673 :   const bool lexico = (tfe && !vector && !is_interp);
     416              :   if (lexico)
     417              :   {
     418              :     // Lexicographic ordering using dof_map.
     419        36650 :     InitLexicoRestr(fespace, indices, use_bdr, ceed, restr);
     420              :   }
     421              :   else
     422              :   {
     423              :     // Native ordering.
     424        42023 :     InitNativeRestr(fespace, indices, use_bdr, is_interp_range, ceed, restr);
     425              :   }
     426        78673 : }
     427              : 
     428              : }  // namespace palace::ceed
        

Generated by: LCOV version 2.0-1