LCOV - code coverage report
Current view: top level - utils - communication.hpp (source / functions) Coverage Total Hit
Test: Palace Coverage Report Lines: 89.3 % 56 50
Test Date: 2025-10-23 22:45:05 Functions: 25.7 % 74 19
Legend: Lines: hit not hit

            Line data    Source code
       1              : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       2              : // SPDX-License-Identifier: Apache-2.0
       3              : 
       4              : #ifndef PALACE_UTILS_COMMUNICATION_HPP
       5              : #define PALACE_UTILS_COMMUNICATION_HPP
       6              : 
       7              : #include <complex>
       8              : #include <fmt/color.h>
       9              : #include <fmt/format.h>
      10              : #include <fmt/printf.h>
      11              : #include <fmt/ranges.h>
      12              : #include <mfem.hpp>
      13              : 
      14              : namespace palace
      15              : {
      16              : 
      17              : namespace mpi
      18              : {
      19              : 
      20              : template <typename T>
      21              : inline MPI_Datatype DataType();
      22              : 
      23              : template <>
      24              : inline MPI_Datatype DataType<char>()
      25              : {
      26              :   return MPI_CHAR;
      27              : }
      28              : 
      29              : template <>
      30              : inline MPI_Datatype DataType<signed char>()
      31              : {
      32              :   return MPI_SIGNED_CHAR;
      33              : }
      34              : 
      35              : template <>
      36              : inline MPI_Datatype DataType<unsigned char>()
      37              : {
      38              :   return MPI_UNSIGNED_CHAR;
      39              : }
      40              : 
      41              : template <>
      42              : inline MPI_Datatype DataType<signed short>()
      43              : {
      44              :   return MPI_SHORT;
      45              : }
      46              : 
      47              : template <>
      48              : inline MPI_Datatype DataType<unsigned short>()
      49              : {
      50              :   return MPI_UNSIGNED_SHORT;
      51              : }
      52              : 
      53              : template <>
      54              : inline MPI_Datatype DataType<signed int>()
      55              : {
      56              :   return MPI_INT;
      57              : }
      58              : 
      59              : template <>
      60              : inline MPI_Datatype DataType<unsigned int>()
      61              : {
      62              :   return MPI_UNSIGNED;
      63              : }
      64              : 
      65              : template <>
      66              : inline MPI_Datatype DataType<signed long int>()
      67              : {
      68              :   return MPI_LONG;
      69              : }
      70              : 
      71              : template <>
      72              : inline MPI_Datatype DataType<unsigned long int>()
      73              : {
      74              :   return MPI_UNSIGNED_LONG;
      75              : }
      76              : 
      77              : template <>
      78              : inline MPI_Datatype DataType<signed long long int>()
      79              : {
      80              :   return MPI_LONG_LONG;
      81              : }
      82              : 
      83              : template <>
      84              : inline MPI_Datatype DataType<unsigned long long int>()
      85              : {
      86              :   return MPI_UNSIGNED_LONG_LONG;
      87              : }
      88              : 
      89              : template <>
      90              : inline MPI_Datatype DataType<float>()
      91              : {
      92              :   return MPI_FLOAT;
      93              : }
      94              : 
      95              : template <>
      96              : inline MPI_Datatype DataType<double>()
      97              : {
      98              :   return MPI_DOUBLE;
      99              : }
     100              : 
     101              : template <>
     102              : inline MPI_Datatype DataType<long double>()
     103              : {
     104              :   return MPI_LONG_DOUBLE;
     105              : }
     106              : 
     107              : template <>
     108              : inline MPI_Datatype DataType<std::complex<float>>()
     109              : {
     110              :   return MPI_C_COMPLEX;
     111              : }
     112              : 
     113              : template <>
     114              : inline MPI_Datatype DataType<std::complex<double>>()
     115              : {
     116              :   return MPI_C_DOUBLE_COMPLEX;
     117              : }
     118              : 
     119              : template <>
     120              : inline MPI_Datatype DataType<std::complex<long double>>()
     121              : {
     122              :   return MPI_C_LONG_DOUBLE_COMPLEX;
     123              : }
     124              : 
     125              : template <>
     126              : inline MPI_Datatype DataType<bool>()
     127              : {
     128              :   return MPI_C_BOOL;
     129              : }
     130              : 
     131              : template <typename T, typename U>
     132              : struct ValueAndLoc
     133              : {
     134              :   T val;
     135              :   U loc;
     136              : };
     137              : 
     138              : template <>
     139              : inline MPI_Datatype DataType<ValueAndLoc<float, signed int>>()
     140              : {
     141              :   return MPI_FLOAT_INT;
     142              : }
     143              : 
     144              : template <>
     145              : inline MPI_Datatype DataType<ValueAndLoc<double, signed int>>()
     146              : {
     147              :   return MPI_DOUBLE_INT;
     148              : }
     149              : 
     150              : template <>
     151              : inline MPI_Datatype DataType<ValueAndLoc<long double, signed int>>()
     152              : {
     153              :   return MPI_LONG_DOUBLE_INT;
     154              : }
     155              : 
     156              : template <>
     157              : inline MPI_Datatype DataType<ValueAndLoc<signed short, signed int>>()
     158              : {
     159              :   return MPI_SHORT_INT;
     160              : }
     161              : 
     162              : template <>
     163              : inline MPI_Datatype DataType<ValueAndLoc<signed int, signed int>>()
     164              : {
     165              :   return MPI_2INT;
     166              : }
     167              : 
     168              : template <>
     169              : inline MPI_Datatype DataType<ValueAndLoc<signed long int, signed int>>()
     170              : {
     171              :   return MPI_LONG_INT;
     172              : }
     173              : 
     174              : }  // namespace mpi
     175              : 
     176              : //
     177              : // A simple convenience class for easy access to some MPI functionality. This is similar to
     178              : // mfem::Mpi and ideally should inherit from it, but the constructor being private instead
     179              : // of protected doesn't allow for that.
     180              : //
     181              : class Mpi
     182              : {
     183              : public:
     184              :   // Singleton creation.
     185              :   static void Init(int requested = default_thread_required)
     186              :   {
     187              :     Init(nullptr, nullptr, requested);
     188              :   }
     189              :   static void Init(int &argc, char **&argv, int requested = default_thread_required)
     190              :   {
     191           66 :     Init(&argc, &argv, requested);
     192              :   }
     193              : 
     194              :   // Finalize MPI (if it has been initialized and not yet already finalized).
     195           66 :   static void Finalize()
     196              :   {
     197              :     if (IsInitialized() && !IsFinalized())
     198              :     {
     199           66 :       MPI_Finalize();
     200              :     }
     201           66 :   }
     202              : 
     203              :   // Return true if MPI has been initialized.
     204              :   static bool IsInitialized()
     205              :   {
     206              :     int is_init;
     207           66 :     int ierr = MPI_Initialized(&is_init);
     208          132 :     return ierr == MPI_SUCCESS && is_init;
     209              :   }
     210              : 
     211              :   // Return true if MPI has been finalized.
     212              :   static bool IsFinalized()
     213              :   {
     214              :     int is_finalized;
     215           66 :     int ierr = MPI_Finalized(&is_finalized);
     216           66 :     return ierr == MPI_SUCCESS && is_finalized;
     217              :   }
     218              : 
     219              :   // Call MPI_Abort with the given error code.
     220              :   static void Abort(int code, MPI_Comm comm = World()) { MPI_Abort(comm, code); }
     221              : 
     222              :   // Barrier on the communicator.
     223          100 :   static void Barrier(MPI_Comm comm = World()) { MPI_Barrier(comm); }
     224              : 
     225              :   // Return processor's rank in the communicator.
     226              :   static int Rank(MPI_Comm comm)
     227              :   {
     228              :     int rank;
     229          820 :     MPI_Comm_rank(comm, &rank);
     230          812 :     return rank;
     231              :   }
     232              : 
     233              :   // Return communicator size.
     234              :   static int Size(MPI_Comm comm)
     235              :   {
     236              :     int size;
     237        12007 :     MPI_Comm_size(comm, &size);
     238        12007 :     return size;
     239              :   }
     240              : 
     241              :   // Return communicator size.
     242            0 :   static bool Root(MPI_Comm comm) { return Rank(comm) == 0; }
     243              : 
     244              :   // Wrapper for MPI_AllReduce.
     245              :   template <typename T>
     246            0 :   static void GlobalOp(int len, T *buff, MPI_Op op, MPI_Comm comm)
     247              :   {
     248          488 :     MPI_Allreduce(MPI_IN_PLACE, buff, len, mpi::DataType<T>(), op, comm);
     249           60 :   }
     250              : 
     251              :   // Global minimum (in-place, result is broadcast to all processes).
     252              :   template <typename T>
     253              :   static void GlobalMin(int len, T *buff, MPI_Comm comm)
     254              :   {
     255              :     GlobalOp(len, buff, MPI_MIN, comm);
     256           84 :   }
     257              : 
     258              :   // Global maximum (in-place, result is broadcast to all processes).
     259              :   template <typename T>
     260              :   static void GlobalMax(int len, T *buff, MPI_Comm comm)
     261              :   {
     262              :     GlobalOp(len, buff, MPI_MAX, comm);
     263            0 :   }
     264              : 
     265              :   // Global sum (in-place, result is broadcast to all processes).
     266              :   template <typename T>
     267              :   static void GlobalSum(int len, T *buff, MPI_Comm comm)
     268              :   {
     269            0 :     GlobalOp(len, buff, MPI_SUM, comm);
     270           88 :   }
     271              : 
     272              :   // Global minimum with index (in-place, result is broadcast to all processes).
     273              :   template <typename T, typename U>
     274              :   static void GlobalMinLoc(int len, T *val, U *loc, MPI_Comm comm)
     275              :   {
     276              :     std::vector<mpi::ValueAndLoc<T, U>> buffer(len);
     277              :     for (int i = 0; i < len; i++)
     278              :     {
     279              :       buffer[i].val = val[i];
     280              :       buffer[i].loc = loc[i];
     281              :     }
     282              :     GlobalOp(len, buffer.data(), MPI_MINLOC, comm);
     283              :     for (int i = 0; i < len; i++)
     284              :     {
     285              :       val[i] = buffer[i].val;
     286              :       loc[i] = buffer[i].loc;
     287              :     }
     288              :   }
     289              : 
     290              :   // Global maximum with index (in-place, result is broadcast to all processes).
     291              :   template <typename T, typename U>
     292           60 :   static void GlobalMaxLoc(int len, T *val, U *loc, MPI_Comm comm)
     293              :   {
     294           60 :     std::vector<mpi::ValueAndLoc<T, U>> buffer(len);
     295          120 :     for (int i = 0; i < len; i++)
     296              :     {
     297           60 :       buffer[i].val = val[i];
     298           60 :       buffer[i].loc = loc[i];
     299              :     }
     300              :     GlobalOp(len, buffer.data(), MPI_MAXLOC, comm);
     301          120 :     for (int i = 0; i < len; i++)
     302              :     {
     303           60 :       val[i] = buffer[i].val;
     304           60 :       loc[i] = buffer[i].loc;
     305              :     }
     306           60 :   }
     307              : 
     308              :   // Global logical or (in-place, result is broadcast to all processes).
     309              :   static void GlobalOr(int len, bool *buff, MPI_Comm comm)
     310              :   {
     311              :     GlobalOp(len, buff, MPI_LOR, comm);
     312           39 :   }
     313              : 
     314              :   // Global logical and (in-place, result is broadcast to all processes).
     315              :   static void GlobalAnd(int len, bool *buff, MPI_Comm comm)
     316              :   {
     317              :     GlobalOp(len, buff, MPI_LAND, comm);
     318            0 :   }
     319              : 
     320              :   // Global broadcast from root.
     321              :   template <typename T>
     322              :   static void Broadcast(int len, T *buff, int root, MPI_Comm comm)
     323              :   {
     324          162 :     MPI_Bcast(buff, len, mpi::DataType<T>(), root, comm);
     325          131 :   }
     326              : 
     327              :   // Print methods only print on the root process of MPI_COMM_WORLD or a given MPI_Comm.
     328              :   template <typename... T>
     329          540 :   static void Print(MPI_Comm comm, fmt::format_string<T...> fmt, T &&...args)
     330              :   {
     331          540 :     if (Root(comm))
     332              :     {
     333              :       fmt::print(fmt, std::forward<T>(args)...);
     334              :     }
     335          540 :   }
     336              : 
     337              :   template <typename... T>
     338              :   static void Print(fmt::format_string<T...> fmt, T &&...args)
     339              :   {
     340          182 :     Print(World(), fmt, std::forward<T>(args)...);
     341          170 :   }
     342              : 
     343              :   template <typename... T>
     344              :   static void Printf(MPI_Comm comm, const char *format, T &&...args)
     345              :   {
     346              :     if (Root(comm))
     347              :     {
     348              :       fmt::printf(format, std::forward<T>(args)...);
     349              :     }
     350              :   }
     351              : 
     352              :   template <typename... T>
     353              :   static void Printf(const char *format, T &&...args)
     354              :   {
     355              :     Printf(World(), format, std::forward<T>(args)...);
     356              :   }
     357              : 
     358              :   template <typename... T>
     359            8 :   static void Warning(MPI_Comm comm, fmt::format_string<T...> fmt, T &&...args)
     360              :   {
     361            8 :     Print(comm, "\n{}\n", fmt::styled("--> Warning!", fmt::fg(fmt::color::yellow)));
     362            8 :     Print(comm, fmt, std::forward<T>(args)...);
     363            8 :     Print(comm, "\n");
     364            8 :   }
     365              : 
     366              :   template <typename... T>
     367              :   static void Warning(fmt::format_string<T...> fmt, T &&...args)
     368              :   {
     369            8 :     Warning(World(), fmt, std::forward<T>(args)...);
     370            8 :   }
     371              : 
     372              :   // Return the global communicator.
     373              :   static MPI_Comm World() { return MPI_COMM_WORLD; }
     374              : 
     375              :   // Default level of threading used in MPI_Init_thread unless provided to Init.
     376              : #if defined(MFEM_USE_OPENMP)
     377              :   inline static int default_thread_required = MPI_THREAD_FUNNELED;
     378              : #else
     379              :   inline static int default_thread_required = MPI_THREAD_SINGLE;
     380              : #endif
     381              : 
     382              : private:
     383              :   // Prevent direct construction of objects of this class.
     384              :   Mpi() = default;
     385           66 :   ~Mpi() { Finalize(); }
     386              : 
     387              :   // Access the singleton instance.
     388           66 :   static Mpi &Instance()
     389              :   {
     390          132 :     static Mpi mpi;
     391           66 :     return mpi;
     392              :   }
     393              : 
     394           66 :   static void Init(int *argc, char ***argv, int requested)
     395              :   {
     396              :     // The Mpi object below needs to be created after MPI_Init() for some MPI
     397              :     // implementations.
     398            0 :     MFEM_VERIFY(!IsInitialized(), "MPI should not be initialized more than once!");
     399              :     int provided;
     400           66 :     MPI_Init_thread(argc, argv, requested, &provided);
     401           66 :     MFEM_VERIFY(provided >= requested,
     402              :                 "MPI could not provide the requested level of thread support!");
     403              :     // Initialize the singleton Instance.
     404           66 :     Instance();
     405           66 :   }
     406              : };
     407              : 
     408              : }  // namespace palace
     409              : 
     410              : #endif  // PALACE_UTILS_COMMUNICATION_HPP
        

Generated by: LCOV version 2.0-1