Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #ifndef PALACE_UTILS_COMMUNICATION_HPP
5 : #define PALACE_UTILS_COMMUNICATION_HPP
6 :
7 : #include <complex>
8 : #include <fmt/color.h>
9 : #include <fmt/format.h>
10 : #include <fmt/printf.h>
11 : #include <fmt/ranges.h>
12 : #include <mfem.hpp>
13 :
14 : namespace palace
15 : {
16 :
17 : namespace mpi
18 : {
19 :
20 : template <typename T>
21 : inline MPI_Datatype DataType();
22 :
23 : template <>
24 : inline MPI_Datatype DataType<char>()
25 : {
26 : return MPI_CHAR;
27 : }
28 :
29 : template <>
30 : inline MPI_Datatype DataType<signed char>()
31 : {
32 : return MPI_SIGNED_CHAR;
33 : }
34 :
35 : template <>
36 : inline MPI_Datatype DataType<unsigned char>()
37 : {
38 : return MPI_UNSIGNED_CHAR;
39 : }
40 :
41 : template <>
42 : inline MPI_Datatype DataType<signed short>()
43 : {
44 : return MPI_SHORT;
45 : }
46 :
47 : template <>
48 : inline MPI_Datatype DataType<unsigned short>()
49 : {
50 : return MPI_UNSIGNED_SHORT;
51 : }
52 :
53 : template <>
54 : inline MPI_Datatype DataType<signed int>()
55 : {
56 : return MPI_INT;
57 : }
58 :
59 : template <>
60 : inline MPI_Datatype DataType<unsigned int>()
61 : {
62 : return MPI_UNSIGNED;
63 : }
64 :
65 : template <>
66 : inline MPI_Datatype DataType<signed long int>()
67 : {
68 : return MPI_LONG;
69 : }
70 :
71 : template <>
72 : inline MPI_Datatype DataType<unsigned long int>()
73 : {
74 : return MPI_UNSIGNED_LONG;
75 : }
76 :
77 : template <>
78 : inline MPI_Datatype DataType<signed long long int>()
79 : {
80 : return MPI_LONG_LONG;
81 : }
82 :
83 : template <>
84 : inline MPI_Datatype DataType<unsigned long long int>()
85 : {
86 : return MPI_UNSIGNED_LONG_LONG;
87 : }
88 :
89 : template <>
90 : inline MPI_Datatype DataType<float>()
91 : {
92 : return MPI_FLOAT;
93 : }
94 :
95 : template <>
96 : inline MPI_Datatype DataType<double>()
97 : {
98 : return MPI_DOUBLE;
99 : }
100 :
101 : template <>
102 : inline MPI_Datatype DataType<long double>()
103 : {
104 : return MPI_LONG_DOUBLE;
105 : }
106 :
107 : template <>
108 : inline MPI_Datatype DataType<std::complex<float>>()
109 : {
110 : return MPI_C_COMPLEX;
111 : }
112 :
113 : template <>
114 : inline MPI_Datatype DataType<std::complex<double>>()
115 : {
116 : return MPI_C_DOUBLE_COMPLEX;
117 : }
118 :
119 : template <>
120 : inline MPI_Datatype DataType<std::complex<long double>>()
121 : {
122 : return MPI_C_LONG_DOUBLE_COMPLEX;
123 : }
124 :
125 : template <>
126 : inline MPI_Datatype DataType<bool>()
127 : {
128 : return MPI_C_BOOL;
129 : }
130 :
131 : template <typename T, typename U>
132 : struct ValueAndLoc
133 : {
134 : T val;
135 : U loc;
136 : };
137 :
138 : template <>
139 : inline MPI_Datatype DataType<ValueAndLoc<float, signed int>>()
140 : {
141 : return MPI_FLOAT_INT;
142 : }
143 :
144 : template <>
145 : inline MPI_Datatype DataType<ValueAndLoc<double, signed int>>()
146 : {
147 : return MPI_DOUBLE_INT;
148 : }
149 :
150 : template <>
151 : inline MPI_Datatype DataType<ValueAndLoc<long double, signed int>>()
152 : {
153 : return MPI_LONG_DOUBLE_INT;
154 : }
155 :
156 : template <>
157 : inline MPI_Datatype DataType<ValueAndLoc<signed short, signed int>>()
158 : {
159 : return MPI_SHORT_INT;
160 : }
161 :
162 : template <>
163 : inline MPI_Datatype DataType<ValueAndLoc<signed int, signed int>>()
164 : {
165 : return MPI_2INT;
166 : }
167 :
168 : template <>
169 : inline MPI_Datatype DataType<ValueAndLoc<signed long int, signed int>>()
170 : {
171 : return MPI_LONG_INT;
172 : }
173 :
174 : } // namespace mpi
175 :
176 : //
177 : // A simple convenience class for easy access to some MPI functionality. This is similar to
178 : // mfem::Mpi and ideally should inherit from it, but the constructor being private instead
179 : // of protected doesn't allow for that.
180 : //
181 : class Mpi
182 : {
183 : public:
184 : // Singleton creation.
185 : static void Init(int requested = default_thread_required)
186 : {
187 : Init(nullptr, nullptr, requested);
188 : }
189 : static void Init(int &argc, char **&argv, int requested = default_thread_required)
190 : {
191 66 : Init(&argc, &argv, requested);
192 : }
193 :
194 : // Finalize MPI (if it has been initialized and not yet already finalized).
195 66 : static void Finalize()
196 : {
197 : if (IsInitialized() && !IsFinalized())
198 : {
199 66 : MPI_Finalize();
200 : }
201 66 : }
202 :
203 : // Return true if MPI has been initialized.
204 : static bool IsInitialized()
205 : {
206 : int is_init;
207 66 : int ierr = MPI_Initialized(&is_init);
208 132 : return ierr == MPI_SUCCESS && is_init;
209 : }
210 :
211 : // Return true if MPI has been finalized.
212 : static bool IsFinalized()
213 : {
214 : int is_finalized;
215 66 : int ierr = MPI_Finalized(&is_finalized);
216 66 : return ierr == MPI_SUCCESS && is_finalized;
217 : }
218 :
219 : // Call MPI_Abort with the given error code.
220 : static void Abort(int code, MPI_Comm comm = World()) { MPI_Abort(comm, code); }
221 :
222 : // Barrier on the communicator.
223 100 : static void Barrier(MPI_Comm comm = World()) { MPI_Barrier(comm); }
224 :
225 : // Return processor's rank in the communicator.
226 : static int Rank(MPI_Comm comm)
227 : {
228 : int rank;
229 820 : MPI_Comm_rank(comm, &rank);
230 812 : return rank;
231 : }
232 :
233 : // Return communicator size.
234 : static int Size(MPI_Comm comm)
235 : {
236 : int size;
237 12007 : MPI_Comm_size(comm, &size);
238 12007 : return size;
239 : }
240 :
241 : // Return communicator size.
242 0 : static bool Root(MPI_Comm comm) { return Rank(comm) == 0; }
243 :
244 : // Wrapper for MPI_AllReduce.
245 : template <typename T>
246 0 : static void GlobalOp(int len, T *buff, MPI_Op op, MPI_Comm comm)
247 : {
248 488 : MPI_Allreduce(MPI_IN_PLACE, buff, len, mpi::DataType<T>(), op, comm);
249 60 : }
250 :
251 : // Global minimum (in-place, result is broadcast to all processes).
252 : template <typename T>
253 : static void GlobalMin(int len, T *buff, MPI_Comm comm)
254 : {
255 : GlobalOp(len, buff, MPI_MIN, comm);
256 84 : }
257 :
258 : // Global maximum (in-place, result is broadcast to all processes).
259 : template <typename T>
260 : static void GlobalMax(int len, T *buff, MPI_Comm comm)
261 : {
262 : GlobalOp(len, buff, MPI_MAX, comm);
263 0 : }
264 :
265 : // Global sum (in-place, result is broadcast to all processes).
266 : template <typename T>
267 : static void GlobalSum(int len, T *buff, MPI_Comm comm)
268 : {
269 0 : GlobalOp(len, buff, MPI_SUM, comm);
270 88 : }
271 :
272 : // Global minimum with index (in-place, result is broadcast to all processes).
273 : template <typename T, typename U>
274 : static void GlobalMinLoc(int len, T *val, U *loc, MPI_Comm comm)
275 : {
276 : std::vector<mpi::ValueAndLoc<T, U>> buffer(len);
277 : for (int i = 0; i < len; i++)
278 : {
279 : buffer[i].val = val[i];
280 : buffer[i].loc = loc[i];
281 : }
282 : GlobalOp(len, buffer.data(), MPI_MINLOC, comm);
283 : for (int i = 0; i < len; i++)
284 : {
285 : val[i] = buffer[i].val;
286 : loc[i] = buffer[i].loc;
287 : }
288 : }
289 :
290 : // Global maximum with index (in-place, result is broadcast to all processes).
291 : template <typename T, typename U>
292 60 : static void GlobalMaxLoc(int len, T *val, U *loc, MPI_Comm comm)
293 : {
294 60 : std::vector<mpi::ValueAndLoc<T, U>> buffer(len);
295 120 : for (int i = 0; i < len; i++)
296 : {
297 60 : buffer[i].val = val[i];
298 60 : buffer[i].loc = loc[i];
299 : }
300 : GlobalOp(len, buffer.data(), MPI_MAXLOC, comm);
301 120 : for (int i = 0; i < len; i++)
302 : {
303 60 : val[i] = buffer[i].val;
304 60 : loc[i] = buffer[i].loc;
305 : }
306 60 : }
307 :
308 : // Global logical or (in-place, result is broadcast to all processes).
309 : static void GlobalOr(int len, bool *buff, MPI_Comm comm)
310 : {
311 : GlobalOp(len, buff, MPI_LOR, comm);
312 39 : }
313 :
314 : // Global logical and (in-place, result is broadcast to all processes).
315 : static void GlobalAnd(int len, bool *buff, MPI_Comm comm)
316 : {
317 : GlobalOp(len, buff, MPI_LAND, comm);
318 0 : }
319 :
320 : // Global broadcast from root.
321 : template <typename T>
322 : static void Broadcast(int len, T *buff, int root, MPI_Comm comm)
323 : {
324 162 : MPI_Bcast(buff, len, mpi::DataType<T>(), root, comm);
325 131 : }
326 :
327 : // Print methods only print on the root process of MPI_COMM_WORLD or a given MPI_Comm.
328 : template <typename... T>
329 540 : static void Print(MPI_Comm comm, fmt::format_string<T...> fmt, T &&...args)
330 : {
331 540 : if (Root(comm))
332 : {
333 : fmt::print(fmt, std::forward<T>(args)...);
334 : }
335 540 : }
336 :
337 : template <typename... T>
338 : static void Print(fmt::format_string<T...> fmt, T &&...args)
339 : {
340 182 : Print(World(), fmt, std::forward<T>(args)...);
341 170 : }
342 :
343 : template <typename... T>
344 : static void Printf(MPI_Comm comm, const char *format, T &&...args)
345 : {
346 : if (Root(comm))
347 : {
348 : fmt::printf(format, std::forward<T>(args)...);
349 : }
350 : }
351 :
352 : template <typename... T>
353 : static void Printf(const char *format, T &&...args)
354 : {
355 : Printf(World(), format, std::forward<T>(args)...);
356 : }
357 :
358 : template <typename... T>
359 8 : static void Warning(MPI_Comm comm, fmt::format_string<T...> fmt, T &&...args)
360 : {
361 8 : Print(comm, "\n{}\n", fmt::styled("--> Warning!", fmt::fg(fmt::color::yellow)));
362 8 : Print(comm, fmt, std::forward<T>(args)...);
363 8 : Print(comm, "\n");
364 8 : }
365 :
366 : template <typename... T>
367 : static void Warning(fmt::format_string<T...> fmt, T &&...args)
368 : {
369 8 : Warning(World(), fmt, std::forward<T>(args)...);
370 8 : }
371 :
372 : // Return the global communicator.
373 : static MPI_Comm World() { return MPI_COMM_WORLD; }
374 :
375 : // Default level of threading used in MPI_Init_thread unless provided to Init.
376 : #if defined(MFEM_USE_OPENMP)
377 : inline static int default_thread_required = MPI_THREAD_FUNNELED;
378 : #else
379 : inline static int default_thread_required = MPI_THREAD_SINGLE;
380 : #endif
381 :
382 : private:
383 : // Prevent direct construction of objects of this class.
384 : Mpi() = default;
385 66 : ~Mpi() { Finalize(); }
386 :
387 : // Access the singleton instance.
388 66 : static Mpi &Instance()
389 : {
390 132 : static Mpi mpi;
391 66 : return mpi;
392 : }
393 :
394 66 : static void Init(int *argc, char ***argv, int requested)
395 : {
396 : // The Mpi object below needs to be created after MPI_Init() for some MPI
397 : // implementations.
398 0 : MFEM_VERIFY(!IsInitialized(), "MPI should not be initialized more than once!");
399 : int provided;
400 66 : MPI_Init_thread(argc, argv, requested, &provided);
401 66 : MFEM_VERIFY(provided >= requested,
402 : "MPI could not provide the requested level of thread support!");
403 : // Initialize the singleton Instance.
404 66 : Instance();
405 66 : }
406 : };
407 :
408 : } // namespace palace
409 :
410 : #endif // PALACE_UTILS_COMMUNICATION_HPP
|