LCOV - code coverage report
Current view: top level - linalg - operator.cpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 14.6 % 261 38
Test Date: 2025-10-23 22:45:05 Functions: 15.8 % 38 6
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #include "operator.hpp"
       5              : 
       6              : #include <mfem/general/forall.hpp>
       7              : #include "linalg/slepc.hpp"
       8              : #include "utils/communication.hpp"
       9              : 
      10              : namespace palace
      11              : {
      12              : 
      13            0 : const Operator *ComplexOperator::Real() const
      14              : {
      15            0 :   MFEM_ABORT("Real() is not implemented for base class ComplexOperator!");
      16              :   return nullptr;
      17              : }
      18              : 
      19            0 : const Operator *ComplexOperator::Imag() const
      20              : {
      21            0 :   MFEM_ABORT("Imag() is not implemented for base class ComplexOperator!");
      22              :   return nullptr;
      23              : }
      24              : 
      25            0 : void ComplexOperator::AssembleDiagonal(ComplexVector &diag) const
      26              : {
      27            0 :   MFEM_ABORT("Base class ComplexOperator does not implement AssembleDiagonal!");
      28              : }
      29              : 
      30            0 : void ComplexOperator::MultTranspose(const ComplexVector &x, ComplexVector &y) const
      31              : {
      32            0 :   MFEM_ABORT("Base class ComplexOperator does not implement MultTranspose!");
      33              : }
      34              : 
      35            0 : void ComplexOperator::MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const
      36              : {
      37            0 :   MFEM_ABORT("Base class ComplexOperator does not implement MultHermitianTranspose!");
      38              : }
      39              : 
      40            0 : void ComplexOperator::AddMult(const ComplexVector &x, ComplexVector &y,
      41              :                               const std::complex<double> a) const
      42              : {
      43            0 :   MFEM_ABORT("Base class ComplexOperator does not implement AddMult!");
      44              : }
      45              : 
      46            0 : void ComplexOperator::AddMultTranspose(const ComplexVector &x, ComplexVector &y,
      47              :                                        const std::complex<double> a) const
      48              : {
      49            0 :   MFEM_ABORT("Base class ComplexOperator does not implement AddMultTranspose!");
      50              : }
      51              : 
      52            0 : void ComplexOperator::AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
      53              :                                                 const std::complex<double> a) const
      54              : {
      55            0 :   MFEM_ABORT("Base class ComplexOperator does not implement AddMultHermitianTranspose!");
      56              : }
      57              : 
      58            9 : ComplexWrapperOperator::ComplexWrapperOperator(std::unique_ptr<Operator> &&dAr,
      59              :                                                std::unique_ptr<Operator> &&dAi,
      60            9 :                                                const Operator *pAr, const Operator *pAi)
      61              :   : ComplexOperator(), data_Ar(std::move(dAr)), data_Ai(std::move(dAi)),
      62           18 :     Ar((data_Ar != nullptr) ? data_Ar.get() : pAr),
      63           27 :     Ai((data_Ai != nullptr) ? data_Ai.get() : pAi)
      64              : {
      65            9 :   MFEM_VERIFY(Ar || Ai, "Cannot construct ComplexWrapperOperator from an empty matrix!");
      66           18 :   MFEM_VERIFY((!Ar || !Ai) || (Ar->Height() == Ai->Height() && Ar->Width() == Ai->Width()),
      67              :               "Mismatch in dimension of real and imaginary matrix parts!");
      68            9 :   tx.UseDevice(true);
      69            9 :   ty.UseDevice(true);
      70            9 :   height = Ar ? Ar->Height() : Ai->Height();
      71            9 :   width = Ar ? Ar->Width() : Ai->Width();
      72            9 : }
      73              : 
      74            9 : ComplexWrapperOperator::ComplexWrapperOperator(std::unique_ptr<Operator> &&Ar,
      75            9 :                                                std::unique_ptr<Operator> &&Ai)
      76            9 :   : ComplexWrapperOperator(std::move(Ar), std::move(Ai), nullptr, nullptr)
      77              : {
      78            9 : }
      79              : 
      80            0 : ComplexWrapperOperator::ComplexWrapperOperator(const Operator *Ar, const Operator *Ai)
      81            0 :   : ComplexWrapperOperator(nullptr, nullptr, Ar, Ai)
      82              : {
      83            0 : }
      84              : 
      85            0 : void ComplexWrapperOperator::AssembleDiagonal(ComplexVector &diag) const
      86              : {
      87              :   diag = 0.0;
      88            0 :   if (Ar)
      89              :   {
      90            0 :     Ar->AssembleDiagonal(diag.Real());
      91              :   }
      92            0 :   if (Ai)
      93              :   {
      94            0 :     Ai->AssembleDiagonal(diag.Imag());
      95              :   }
      96            0 : }
      97              : 
      98            9 : void ComplexWrapperOperator::Mult(const ComplexVector &x, ComplexVector &y) const
      99              : {
     100              :   constexpr bool zero_real = false;
     101              :   constexpr bool zero_imag = false;
     102              :   const Vector &xr = x.Real();
     103              :   const Vector &xi = x.Imag();
     104              :   Vector &yr = y.Real();
     105              :   Vector &yi = y.Imag();
     106            9 :   if (Ai)
     107              :   {
     108              :     if (!zero_imag)
     109              :     {
     110            9 :       Ai->Mult(xi, yr);
     111            9 :       yr *= -1.0;
     112              :     }
     113              :     if (!zero_real)
     114              :     {
     115            9 :       Ai->Mult(xr, yi);
     116              :     }
     117              :   }
     118              :   else
     119              :   {
     120            0 :     yr = 0.0;
     121            0 :     yi = 0.0;
     122              :   }
     123            9 :   if (Ar)
     124              :   {
     125              :     if (!zero_real)
     126              :     {
     127            9 :       Ar->AddMult(xr, yr);
     128              :     }
     129              :     if (!zero_imag)
     130              :     {
     131            9 :       Ar->AddMult(xi, yi);
     132              :     }
     133              :   }
     134            9 : }
     135              : 
     136            0 : void ComplexWrapperOperator::MultTranspose(const ComplexVector &x, ComplexVector &y) const
     137              : {
     138              :   constexpr bool zero_real = false;
     139              :   constexpr bool zero_imag = false;
     140              :   const Vector &xr = x.Real();
     141              :   const Vector &xi = x.Imag();
     142              :   Vector &yr = y.Real();
     143              :   Vector &yi = y.Imag();
     144            0 :   if (Ai)
     145              :   {
     146              :     if (!zero_imag)
     147              :     {
     148            0 :       Ai->MultTranspose(xi, yr);
     149            0 :       yr *= -1.0;
     150              :     }
     151              :     if (!zero_real)
     152              :     {
     153            0 :       Ai->MultTranspose(xr, yi);
     154              :     }
     155              :   }
     156              :   else
     157              :   {
     158            0 :     yr = 0.0;
     159            0 :     yi = 0.0;
     160              :   }
     161            0 :   if (Ar)
     162              :   {
     163              :     if (!zero_real)
     164              :     {
     165            0 :       Ar->AddMultTranspose(xr, yr);
     166              :     }
     167              :     if (!zero_imag)
     168              :     {
     169            0 :       Ar->AddMultTranspose(xi, yi);
     170              :     }
     171              :   }
     172            0 : }
     173              : 
     174            0 : void ComplexWrapperOperator::MultHermitianTranspose(const ComplexVector &x,
     175              :                                                     ComplexVector &y) const
     176              : {
     177              :   constexpr bool zero_real = false;
     178              :   constexpr bool zero_imag = false;
     179              :   const Vector &xr = x.Real();
     180              :   const Vector &xi = x.Imag();
     181              :   Vector &yr = y.Real();
     182              :   Vector &yi = y.Imag();
     183            0 :   if (Ai)
     184              :   {
     185              :     if (!zero_imag)
     186              :     {
     187            0 :       Ai->MultTranspose(xi, yr);
     188              :     }
     189              :     if (!zero_real)
     190              :     {
     191            0 :       Ai->MultTranspose(xr, yi);
     192            0 :       yi *= -1.0;
     193              :     }
     194              :   }
     195              :   else
     196              :   {
     197            0 :     yr = 0.0;
     198            0 :     yi = 0.0;
     199              :   }
     200            0 :   if (Ar)
     201              :   {
     202              :     if (!zero_real)
     203              :     {
     204            0 :       Ar->AddMultTranspose(xr, yr);
     205              :     }
     206              :     if (!zero_imag)
     207              :     {
     208            0 :       Ar->AddMultTranspose(xi, yi);
     209              :     }
     210              :   }
     211            0 : }
     212              : 
     213            0 : void ComplexWrapperOperator::AddMult(const ComplexVector &x, ComplexVector &y,
     214              :                                      const std::complex<double> a) const
     215              : {
     216              :   constexpr bool zero_real = false;
     217              :   constexpr bool zero_imag = false;
     218              :   const Vector &xr = x.Real();
     219              :   const Vector &xi = x.Imag();
     220              :   Vector &yr = y.Real();
     221              :   Vector &yi = y.Imag();
     222            0 :   if (a.real() != 0.0 && a.imag() != 0.0)
     223              :   {
     224            0 :     ty.SetSize(height);
     225            0 :     Mult(x, ty);
     226            0 :     y.AXPY(a, ty);
     227              :   }
     228            0 :   else if (a.real() != 0.0)
     229              :   {
     230            0 :     if (Ar)
     231              :     {
     232              :       if (!zero_real)
     233              :       {
     234            0 :         Ar->AddMult(xr, yr, a.real());
     235              :       }
     236              :       if (!zero_imag)
     237              :       {
     238            0 :         Ar->AddMult(xi, yi, a.real());
     239              :       }
     240              :     }
     241            0 :     if (Ai)
     242              :     {
     243              :       if (!zero_imag)
     244              :       {
     245            0 :         Ai->AddMult(xi, yr, -a.real());
     246              :       }
     247              :       if (!zero_real)
     248              :       {
     249            0 :         Ai->AddMult(xr, yi, a.real());
     250              :       }
     251              :     }
     252              :   }
     253            0 :   else if (a.imag() != 0.0)
     254              :   {
     255            0 :     if (Ar)
     256              :     {
     257              :       if (!zero_real)
     258              :       {
     259            0 :         Ar->AddMult(xr, yi, a.imag());
     260              :       }
     261              :       if (!zero_imag)
     262              :       {
     263            0 :         Ar->AddMult(xi, yr, -a.imag());
     264              :       }
     265              :     }
     266            0 :     if (Ai)
     267              :     {
     268              :       if (!zero_imag)
     269              :       {
     270            0 :         Ai->AddMult(xi, yi, -a.imag());
     271              :       }
     272              :       if (!zero_real)
     273              :       {
     274            0 :         Ai->AddMult(xr, yr, -a.imag());
     275              :       }
     276              :     }
     277              :   }
     278            0 : }
     279              : 
     280            0 : void ComplexWrapperOperator::AddMultTranspose(const ComplexVector &x, ComplexVector &y,
     281              :                                               const std::complex<double> a) const
     282              : {
     283              :   constexpr bool zero_real = false;
     284              :   constexpr bool zero_imag = false;
     285              :   const Vector &xr = x.Real();
     286              :   const Vector &xi = x.Imag();
     287              :   Vector &yr = y.Real();
     288              :   Vector &yi = y.Imag();
     289            0 :   if (a.real() != 0.0 && a.imag() != 0.0)
     290              :   {
     291            0 :     tx.SetSize(width);
     292            0 :     MultTranspose(x, tx);
     293            0 :     y.AXPY(a, tx);
     294              :   }
     295            0 :   else if (a.real() != 0.0)
     296              :   {
     297            0 :     if (Ar)
     298              :     {
     299              :       if (!zero_real)
     300              :       {
     301            0 :         Ar->AddMultTranspose(xr, yr, a.real());
     302              :       }
     303              :       if (!zero_imag)
     304              :       {
     305            0 :         Ar->AddMultTranspose(xi, yi, a.real());
     306              :       }
     307              :     }
     308            0 :     if (Ai)
     309              :     {
     310              :       if (!zero_imag)
     311              :       {
     312            0 :         Ai->AddMultTranspose(xi, yr, -a.real());
     313              :       }
     314              :       if (!zero_real)
     315              :       {
     316            0 :         Ai->AddMultTranspose(xr, yi, a.real());
     317              :       }
     318              :     }
     319              :   }
     320            0 :   else if (a.imag() != 0.0)
     321              :   {
     322            0 :     if (Ar)
     323              :     {
     324              :       if (!zero_real)
     325              :       {
     326            0 :         Ar->AddMultTranspose(xr, yi, a.imag());
     327              :       }
     328              :       if (!zero_imag)
     329              :       {
     330            0 :         Ar->AddMultTranspose(xi, yr, -a.imag());
     331              :       }
     332              :     }
     333            0 :     if (Ai)
     334              :     {
     335              :       if (!zero_imag)
     336              :       {
     337            0 :         Ai->AddMultTranspose(xi, yi, -a.imag());
     338              :       }
     339              :       if (!zero_real)
     340              :       {
     341            0 :         Ai->AddMultTranspose(xr, yr, -a.imag());
     342              :       }
     343              :     }
     344              :   }
     345            0 : }
     346              : 
     347            0 : void ComplexWrapperOperator::AddMultHermitianTranspose(const ComplexVector &x,
     348              :                                                        ComplexVector &y,
     349              :                                                        const std::complex<double> a) const
     350              : {
     351              :   constexpr bool zero_real = false;
     352              :   constexpr bool zero_imag = false;
     353              :   const Vector &xr = x.Real();
     354              :   const Vector &xi = x.Imag();
     355              :   Vector &yr = y.Real();
     356              :   Vector &yi = y.Imag();
     357            0 :   if (a.real() != 0.0 && a.imag() != 0.0)
     358              :   {
     359            0 :     tx.SetSize(width);
     360            0 :     MultHermitianTranspose(x, tx);
     361            0 :     y.AXPY(a, tx);
     362              :   }
     363            0 :   else if (a.real() != 0.0)
     364              :   {
     365            0 :     if (Ar)
     366              :     {
     367              :       if (!zero_real)
     368              :       {
     369            0 :         Ar->AddMultTranspose(xr, yr, a.real());
     370              :       }
     371              :       if (!zero_imag)
     372              :       {
     373            0 :         Ar->AddMultTranspose(xi, yi, a.real());
     374              :       }
     375              :     }
     376            0 :     if (Ai)
     377              :     {
     378              :       if (!zero_imag)
     379              :       {
     380            0 :         Ai->AddMultTranspose(xi, yr, a.real());
     381              :       }
     382              :       if (!zero_real)
     383              :       {
     384            0 :         Ai->AddMultTranspose(xr, yi, -a.real());
     385              :       }
     386              :     }
     387              :   }
     388            0 :   else if (a.imag() != 0.0)
     389              :   {
     390            0 :     if (Ar)
     391              :     {
     392              :       if (!zero_real)
     393              :       {
     394            0 :         Ar->AddMultTranspose(xr, yi, a.imag());
     395              :       }
     396              :       if (!zero_imag)
     397              :       {
     398            0 :         Ar->AddMultTranspose(xi, yr, -a.imag());
     399              :       }
     400              :     }
     401            0 :     if (Ai)
     402              :     {
     403              :       if (!zero_imag)
     404              :       {
     405            0 :         Ai->AddMultTranspose(xi, yi, a.imag());
     406              :       }
     407              :       if (!zero_real)
     408              :       {
     409            0 :         Ai->AddMultTranspose(xr, yr, a.imag());
     410              :       }
     411              :     }
     412              :   }
     413            0 : }
     414              : 
     415            0 : SumOperator::SumOperator(const Operator &op, double a) : Operator(op.Height(), op.Width())
     416              : {
     417            0 :   AddOperator(op, a);
     418              :   z.UseDevice(true);
     419            0 : }
     420              : 
     421           30 : void SumOperator::AddOperator(const Operator &op, double a)
     422              : {
     423           30 :   MFEM_VERIFY(op.Height() == height && op.Width() == width,
     424              :               "Invalid Operator dimensions for SumOperator!");
     425           30 :   ops.emplace_back(&op, a);
     426           30 : }
     427              : 
     428            9 : void SumOperator::Mult(const Vector &x, Vector &y) const
     429              : {
     430            9 :   if (ops.size() == 1)
     431              :   {
     432            0 :     ops.front().first->Mult(x, y);
     433            0 :     if (ops.front().second != 1.0)
     434              :     {
     435            0 :       y *= ops.front().second;
     436              :     }
     437            0 :     return;
     438              :   }
     439            9 :   y = 0.0;
     440            9 :   AddMult(x, y);
     441              : }
     442              : 
     443            0 : void SumOperator::MultTranspose(const Vector &x, Vector &y) const
     444              : {
     445            0 :   if (ops.size() == 1)
     446              :   {
     447            0 :     ops.front().first->MultTranspose(x, y);
     448            0 :     if (ops.front().second != 1.0)
     449              :     {
     450            0 :       y *= ops.front().second;
     451              :     }
     452            0 :     return;
     453              :   }
     454            0 :   y = 0.0;
     455            0 :   AddMultTranspose(x, y);
     456              : }
     457              : 
     458           15 : void SumOperator::AddMult(const Vector &x, Vector &y, const double a) const
     459              : {
     460           15 :   z.SetSize(y.Size());
     461           69 :   for (const auto &[op, c] : ops)
     462              :   {
     463           54 :     op->Mult(x, z);
     464           54 :     y.Add(a * c, z);
     465              :   }
     466           15 : }
     467              : 
     468            0 : void SumOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) const
     469              : {
     470            0 :   z.SetSize(y.Size());
     471            0 :   for (const auto &[op, c] : ops)
     472              :   {
     473            0 :     op->MultTranspose(x, z);
     474            0 :     y.Add(a * c, z);
     475              :   }
     476            0 : }
     477              : 
     478              : template <>
     479            0 : void BaseDiagonalOperator<Operator>::Mult(const Vector &x, Vector &y) const
     480              : {
     481            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     482            0 :   const int N = this->height;
     483            0 :   const auto *D = d.Read(use_dev);
     484            0 :   const auto *X = x.Read(use_dev);
     485            0 :   auto *Y = y.Write(use_dev);
     486            0 :   mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { Y[i] = D[i] * X[i]; });
     487            0 : }
     488              : 
     489              : template <>
     490            0 : void BaseDiagonalOperator<ComplexOperator>::Mult(const ComplexVector &x,
     491              :                                                  ComplexVector &y) const
     492              : {
     493            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     494            0 :   const int N = this->height;
     495            0 :   const auto *DR = d.Real().Read(use_dev);
     496            0 :   const auto *DI = d.Imag().Read(use_dev);
     497            0 :   const auto *XR = x.Real().Read(use_dev);
     498            0 :   const auto *XI = x.Imag().Read(use_dev);
     499            0 :   auto *YR = y.Real().Write(use_dev);
     500            0 :   auto *YI = y.Imag().Write(use_dev);
     501              :   mfem::forall_switch(use_dev, N,
     502            0 :                       [=] MFEM_HOST_DEVICE(int i)
     503              :                       {
     504            0 :                         YR[i] = DR[i] * XR[i] - DI[i] * XI[i];
     505            0 :                         YI[i] = DI[i] * XR[i] + DR[i] * XI[i];
     506            0 :                       });
     507            0 : }
     508              : 
     509              : template <>
     510            0 : void BaseDiagonalOperator<Operator>::AddMult(const Vector &x, Vector &y,
     511              :                                              const double a) const
     512              : {
     513            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     514            0 :   const int N = this->height;
     515            0 :   const auto *D = d.Read(use_dev);
     516            0 :   const auto *X = x.Read(use_dev);
     517            0 :   auto *Y = y.Write(use_dev);
     518            0 :   mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { Y[i] += a * D[i] * X[i]; });
     519            0 : }
     520              : 
     521              : template <>
     522            0 : void BaseDiagonalOperator<ComplexOperator>::AddMult(const ComplexVector &x,
     523              :                                                     ComplexVector &y,
     524              :                                                     const std::complex<double> a) const
     525              : {
     526            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     527            0 :   const int N = this->height;
     528              :   const double ar = a.real();
     529              :   const double ai = a.imag();
     530            0 :   const auto *DR = d.Real().Read(use_dev);
     531            0 :   const auto *DI = d.Imag().Read(use_dev);
     532            0 :   const auto *XR = x.Real().Read(use_dev);
     533            0 :   const auto *XI = x.Imag().Read(use_dev);
     534            0 :   auto *YR = y.Real().Write(use_dev);
     535            0 :   auto *YI = y.Imag().Write(use_dev);
     536              :   mfem::forall_switch(use_dev, N,
     537            0 :                       [=] MFEM_HOST_DEVICE(int i)
     538              :                       {
     539            0 :                         const auto tr = DR[i] * XR[i] - DI[i] * XI[i];
     540            0 :                         const auto ti = DI[i] * XR[i] + DR[i] * XI[i];
     541            0 :                         YR[i] += ar * tr - ai * ti;
     542            0 :                         YI[i] += ai * ti + ar * ti;
     543            0 :                       });
     544            0 : }
     545              : 
     546              : template <>
     547            0 : void DiagonalOperatorHelper<BaseDiagonalOperator<ComplexOperator>,
     548              :                             ComplexOperator>::MultHermitianTranspose(const ComplexVector &x,
     549              :                                                                      ComplexVector &y) const
     550              : {
     551            0 :   const ComplexVector &d =
     552              :       static_cast<const BaseDiagonalOperator<ComplexOperator> *>(this)->d;
     553            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     554            0 :   const int N = this->height;
     555            0 :   const auto *DR = d.Real().Read(use_dev);
     556            0 :   const auto *DI = d.Imag().Read(use_dev);
     557            0 :   const auto *XR = x.Real().Read(use_dev);
     558            0 :   const auto *XI = x.Imag().Read(use_dev);
     559            0 :   auto *YR = y.Real().Write(use_dev);
     560            0 :   auto *YI = y.Imag().Write(use_dev);
     561              :   mfem::forall_switch(use_dev, N,
     562            0 :                       [=] MFEM_HOST_DEVICE(int i)
     563              :                       {
     564            0 :                         YR[i] = DR[i] * XR[i] + DI[i] * XI[i];
     565            0 :                         YI[i] = -DI[i] * XR[i] + DR[i] * XI[i];
     566            0 :                       });
     567            0 : }
     568              : 
     569              : template <>
     570            0 : void DiagonalOperatorHelper<BaseDiagonalOperator<ComplexOperator>, ComplexOperator>::
     571              :     AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
     572              :                               const std::complex<double> a) const
     573              : {
     574            0 :   const ComplexVector &d =
     575              :       static_cast<const BaseDiagonalOperator<ComplexOperator> *>(this)->d;
     576            0 :   const bool use_dev = x.UseDevice() || y.UseDevice();
     577            0 :   const int N = this->height;
     578              :   const double ar = a.real();
     579              :   const double ai = a.imag();
     580            0 :   const auto *DR = d.Real().Read(use_dev);
     581            0 :   const auto *DI = d.Imag().Read(use_dev);
     582            0 :   const auto *XR = x.Real().Read(use_dev);
     583            0 :   const auto *XI = x.Imag().Read(use_dev);
     584            0 :   auto *YR = y.Real().Write(use_dev);
     585            0 :   auto *YI = y.Imag().Write(use_dev);
     586              :   mfem::forall_switch(use_dev, N,
     587            0 :                       [=] MFEM_HOST_DEVICE(int i)
     588              :                       {
     589            0 :                         const auto tr = DR[i] * XR[i] + DI[i] * XI[i];
     590            0 :                         const auto ti = -DI[i] * XR[i] + DR[i] * XI[i];
     591            0 :                         YR[i] += ar * tr - ai * ti;
     592            0 :                         YI[i] += ai * ti + ar * ti;
     593            0 :                       });
     594            0 : }
     595              : 
     596              : namespace linalg
     597              : {
     598              : 
     599              : template <>
     600            0 : double Norml2(MPI_Comm comm, const Vector &x, const Operator &B, Vector &Bx)
     601              : {
     602            0 :   B.Mult(x, Bx);
     603            0 :   double dot = Dot(comm, Bx, x);
     604              :   MFEM_ASSERT(dot > 0.0,
     605              :               "Non-positive vector norm in normalization (dot = " << dot << ")!");
     606            0 :   return std::sqrt(dot);
     607              : }
     608              : 
     609              : template <>
     610            0 : double Norml2(MPI_Comm comm, const ComplexVector &x, const Operator &B, ComplexVector &Bx)
     611              : {
     612              :   // For SPD B, xᴴ B x is real.
     613            0 :   B.Mult(x.Real(), Bx.Real());
     614            0 :   B.Mult(x.Imag(), Bx.Imag());
     615            0 :   std::complex<double> dot = Dot(comm, Bx, x);
     616              :   MFEM_ASSERT(dot.real() > 0.0 && std::abs(dot.imag()) < 1.0e-9 * dot.real(),
     617              :               "Non-positive vector norm in normalization (dot = " << dot << ")!");
     618            0 :   return std::sqrt(dot.real());
     619              : }
     620              : 
     621            0 : double SpectralNorm(MPI_Comm comm, const Operator &A, bool sym, double tol, int max_it)
     622              : {
     623            0 :   ComplexWrapperOperator Ar(const_cast<Operator *>(&A), nullptr);  // Non-owning constructor
     624            0 :   return SpectralNorm(comm, Ar, sym, tol, max_it);
     625            0 : }
     626              : 
     627            0 : double SpectralNorm(MPI_Comm comm, const ComplexOperator &A, bool herm, double tol,
     628              :                     int max_it)
     629              : {
     630              :   // XX TODO: Use ARPACK or SLEPc for this when configured.
     631              : #if defined(PALACE_WITH_SLEPC)
     632            0 :   return slepc::GetMaxSingularValue(comm, A, herm, tol, max_it);
     633              : #else
     634              :   // Power iteration loop: ||A||₂² = λₙ(Aᴴ A).
     635              :   int it = 0;
     636              :   double res = 0.0;
     637              :   double l = 0.0, l0 = 0.0;
     638              :   ComplexVector u(A.Height()), v(A.Height());
     639              :   u.UseDevice(true);
     640              :   v.UseDevice(true);
     641              :   SetRandom(comm, u);
     642              :   Normalize(comm, u);
     643              :   while (it < max_it)
     644              :   {
     645              :     A.Mult(u, v);
     646              :     if (herm)
     647              :     {
     648              :       u = v;
     649              :     }
     650              :     else
     651              :     {
     652              :       A.MultHermitianTranspose(v, u);
     653              :     }
     654              :     l = Normalize(comm, u);
     655              :     if (it > 0)
     656              :     {
     657              :       res = std::abs(l - l0) / l0;
     658              :       if (res < tol)
     659              :       {
     660              :         break;
     661              :       }
     662              :     }
     663              :     l0 = l;
     664              :     it++;
     665              :   }
     666              :   if (it >= max_it)
     667              :   {
     668              :     Mpi::Warning(comm,
     669              :                  "Power iteration did not converge in {:d} iterations, res = {:.3e}, "
     670              :                  "lambda = {:.3e}!\n",
     671              :                  it, res, l);
     672              :   }
     673              :   return herm ? l : std::sqrt(l);
     674              : #endif
     675              : }
     676              : 
     677              : }  // namespace linalg
     678              : 
     679              : }  // namespace palace
        

Generated by: LCOV version 2.0-1