Line data Source code
1 : // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 : // SPDX-License-Identifier: Apache-2.0
3 :
4 : #include "device.hpp"
5 : #include "communication.hpp"
6 :
7 : #include <mfem.hpp>
8 :
9 : namespace palace::utils
10 : {
11 :
12 66 : int GetDeviceCount()
13 : {
14 : #if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
15 : return mfem::Device::GetDeviceCount();
16 : #else
17 66 : return 0;
18 : #endif
19 : }
20 :
21 66 : int GetDeviceId(MPI_Comm comm, int ngpu)
22 : {
23 : // Assign devices round-robin over MPI ranks if GPU support is enabled.
24 : #if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
25 : MPI_Comm node_comm;
26 : MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, Mpi::Rank(comm), MPI_INFO_NULL,
27 : &node_comm);
28 : int node_size = Mpi::Rank(node_comm);
29 : MPI_Comm_free(&node_comm);
30 : return node_size % ngpu;
31 : #else
32 66 : return 0;
33 : #endif
34 : }
35 :
36 : } // namespace palace::utils
|