00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #include "Teuchos_MPIComm.hpp"
00030 #include "Teuchos_ErrorPolling.hpp"
00031
00032
using namespace Teuchos;

namespace Teuchos
{
// Portable data-type codes accepted by the communication wrappers;
// translated to MPI_Datatype values in MPIComm::getDataType().
const int MPIComm::INT = 1;
const int MPIComm::FLOAT = 2;
const int MPIComm::DOUBLE = 3;
const int MPIComm::CHAR = 4;

// Portable reduction-operator codes; translated to MPI_Op values in
// MPIComm::getOp().
const int MPIComm::SUM = 5;
const int MPIComm::MIN = 6;
const int MPIComm::MAX = 7;
const int MPIComm::PROD = 8;
}
00047
00048
/** Default constructor: wraps MPI_COMM_WORLD when built with MPI,
 * otherwise behaves as a single-process serial communicator.
 * Rank and size are filled in by init(). */
MPIComm::MPIComm()
	:
#ifdef HAVE_MPI
	comm_(MPI_COMM_WORLD),
#endif
	nProc_(0), myRank_(0)
{
	init();
}
00058
#ifdef HAVE_MPI
/** Wrap an existing MPI communicator. The communicator is NOT
 * duplicated or owned; the caller retains responsibility for it.
 * Rank and size are filled in by init(). */
MPIComm::MPIComm(MPI_Comm comm)
	: comm_(comm), nProc_(0), myRank_(0)
{
	init();
}
#endif
00066
00067 int MPIComm::mpiIsRunning() const
00068 {
00069 int mpiStarted = 0;
00070 #ifdef HAVE_MPI
00071 MPI_Initialized(&mpiStarted);
00072 #endif
00073 return mpiStarted;
00074 }
00075
00076 void MPIComm::init()
00077 {
00078 #ifdef HAVE_MPI
00079
00080 if (mpiIsRunning())
00081 {
00082 errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
00083 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
00084 }
00085 else
00086 {
00087 nProc_ = 1;
00088 myRank_ = 0;
00089 }
00090
00091 #else
00092 nProc_ = 1;
00093 myRank_ = 0;
00094 #endif
00095 }
00096
#ifdef USE_MPI_GROUPS

/** Construct a communicator restricted to the processes of the given
 * group within the parent communicator.
 *
 * Processes outside the group (or outside the parent) are marked with
 * rank -1 and size -1; an empty group yields rank -1 and size 0.
 *
 * Fix: the original body assigned to a nonexistent member `rank_`;
 * the rank member used throughout this class (see init()) is
 * `myRank_`. With USE_MPI_GROUPS defined the original would not have
 * compiled.
 */
MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
	:
#ifdef HAVE_MPI
	comm_(MPI_COMM_WORLD),
#endif
	nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
	if (group.getNProc()==0)
		{
			myRank_ = -1;
			nProc_ = 0;
		}
	else if (parent.containsMe())
		{
			MPI_Comm parentComm = parent.comm_;
			MPI_Group newGroup = group.group_;

			// Collective over the parent communicator: every member of the
			// parent must call this, even those not in the new group.
			errCheck(MPI_Comm_create(parentComm, newGroup, &comm_),
							 "Comm_create");

			if (group.containsProc(parent.getRank()))
				{
					errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");

					errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
				}
			else
				{
					// Not in the new group: MPI_Comm_create gave us
					// MPI_COMM_NULL, so rank/size queries would be invalid.
					myRank_ = -1;
					nProc_ = -1;
					return;
				}
		}
	else
		{
			myRank_ = -1;
			nProc_ = -1;
		}
#endif
}

#endif
00142
00143 MPIComm& MPIComm::world()
00144 {
00145 static MPIComm w = MPIComm();
00146 return w;
00147 }
00148
00149
00150 void MPIComm::synchronize() const
00151 {
00152 #ifdef HAVE_MPI
00153
00154 {
00155 if (mpiIsRunning())
00156 {
00157
00158
00159 TEUCHOS_POLL_FOR_FAILURES(*this);
00160
00161
00162 errCheck(::MPI_Barrier(comm_), "Barrier");
00163 }
00164 }
00165
00166 #endif
00167 }
00168
00169 void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
00170 void* recvBuf, int recvCount, int recvType) const
00171 {
00172 #ifdef HAVE_MPI
00173
00174 {
00175 MPI_Datatype mpiSendType = getDataType(sendType);
00176 MPI_Datatype mpiRecvType = getDataType(recvType);
00177
00178
00179 if (mpiIsRunning())
00180 {
00181
00182
00183 TEUCHOS_POLL_FOR_FAILURES(*this);
00184
00185
00186 errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
00187 recvBuf, recvCount, mpiRecvType,
00188 comm_), "Alltoall");
00189 }
00190 }
00191
00192 #endif
00193 }
00194
00195 void MPIComm::allToAllv(void* sendBuf, int* sendCount,
00196 int* sendDisplacements, int sendType,
00197 void* recvBuf, int* recvCount,
00198 int* recvDisplacements, int recvType) const
00199 {
00200 #ifdef HAVE_MPI
00201
00202 {
00203 MPI_Datatype mpiSendType = getDataType(sendType);
00204 MPI_Datatype mpiRecvType = getDataType(recvType);
00205
00206 if (mpiIsRunning())
00207 {
00208
00209
00210 TEUCHOS_POLL_FOR_FAILURES(*this);
00211
00212
00213 errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
00214 recvBuf, recvCount, recvDisplacements, mpiRecvType,
00215 comm_), "Alltoallv");
00216 }
00217 }
00218
00219 #endif
00220 }
00221
00222 void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
00223 void* recvBuf, int recvCount, int recvType,
00224 int root) const
00225 {
00226 #ifdef HAVE_MPI
00227
00228 {
00229 MPI_Datatype mpiSendType = getDataType(sendType);
00230 MPI_Datatype mpiRecvType = getDataType(recvType);
00231
00232
00233 if (mpiIsRunning())
00234 {
00235
00236
00237 TEUCHOS_POLL_FOR_FAILURES(*this);
00238
00239
00240 errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
00241 recvBuf, recvCount, mpiRecvType,
00242 root, comm_), "Gather");
00243 }
00244 }
00245
00246 #endif
00247 }
00248
00249 void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType,
00250 void* recvBuf, int* recvCount, int* displacements, int recvType,
00251 int root) const
00252 {
00253 #ifdef HAVE_MPI
00254
00255 {
00256 MPI_Datatype mpiSendType = getDataType(sendType);
00257 MPI_Datatype mpiRecvType = getDataType(recvType);
00258
00259 if (mpiIsRunning())
00260 {
00261
00262
00263 TEUCHOS_POLL_FOR_FAILURES(*this);
00264
00265
00266 errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
00267 recvBuf, recvCount, displacements, mpiRecvType,
00268 root, comm_), "Gatherv");
00269 }
00270 }
00271
00272 #endif
00273 }
00274
00275 void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
00276 void* recvBuf, int recvCount,
00277 int recvType) const
00278 {
00279 #ifdef HAVE_MPI
00280
00281 {
00282 MPI_Datatype mpiSendType = getDataType(sendType);
00283 MPI_Datatype mpiRecvType = getDataType(recvType);
00284
00285 if (mpiIsRunning())
00286 {
00287
00288
00289 TEUCHOS_POLL_FOR_FAILURES(*this);
00290
00291
00292 errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
00293 recvBuf, recvCount,
00294 mpiRecvType, comm_),
00295 "AllGather");
00296 }
00297 }
00298
00299 #endif
00300 }
00301
00302
00303 void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
00304 void* recvBuf, int* recvCount,
00305 int* recvDisplacements,
00306 int recvType) const
00307 {
00308 #ifdef HAVE_MPI
00309
00310 {
00311 MPI_Datatype mpiSendType = getDataType(sendType);
00312 MPI_Datatype mpiRecvType = getDataType(recvType);
00313
00314 if (mpiIsRunning())
00315 {
00316
00317
00318 TEUCHOS_POLL_FOR_FAILURES(*this);
00319
00320
00321 errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
00322 recvBuf, recvCount, recvDisplacements,
00323 mpiRecvType,
00324 comm_),
00325 "AllGatherv");
00326 }
00327 }
00328
00329 #endif
00330 }
00331
00332
00333 void MPIComm::bcast(void* msg, int length, int type, int src) const
00334 {
00335 #ifdef HAVE_MPI
00336
00337 {
00338 if (mpiIsRunning())
00339 {
00340
00341
00342 TEUCHOS_POLL_FOR_FAILURES(*this);
00343
00344
00345 MPI_Datatype mpiType = getDataType(type);
00346 errCheck(::MPI_Bcast(msg, length, mpiType, src,
00347 comm_), "Bcast");
00348 }
00349 }
00350
00351 #endif
00352 }
00353
/** Combine `inputCount` elements from every process with the given
 * reduction operator and distribute the result to all processes
 * (MPI_Allreduce). The op and type codes are validated (and may
 * throw via getOp/getDataType) even when MPI is not running; the
 * reduction itself is then skipped.
 *
 * NOTE(review): unlike the other collectives in this file, this one
 * does not call TEUCHOS_POLL_FOR_FAILURES before the collective —
 * possibly intentional, but worth confirming. */
void MPIComm::allReduce(void* input, void* result, int inputCount,
												int type, int op) const
{
#ifdef HAVE_MPI

	{
		MPI_Op mpiOp = getOp(op);
		MPI_Datatype mpiType = getDataType(type);

		if (mpiIsRunning())
			{
				errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
																 mpiOp, comm_),
								 "Allreduce");
			}
	}

#endif
}
00374
00375
00376 #ifdef HAVE_MPI
00377
00378 MPI_Datatype MPIComm::getDataType(int type)
00379 {
00380 TEST_FOR_EXCEPTION(
00381 !(type == INT || type==FLOAT
00382 || type==DOUBLE || type==CHAR),
00383 std::range_error,
00384 "invalid type " << type << " in MPIComm::getDataType");
00385
00386 if(type == INT) return MPI_INT;
00387 if(type == FLOAT) return MPI_FLOAT;
00388 if(type == DOUBLE) return MPI_DOUBLE;
00389
00390 return MPI_CHAR;
00391 }
00392
00393
/** Turn a nonzero MPI return code into a std::runtime_error whose
 * message names the failing MPI call (methodName is the suffix after
 * "MPI_"). No-op when errCode == 0 (MPI_SUCCESS). */
void MPIComm::errCheck(int errCode, const std::string& methodName)
{
	TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
										 "MPI function MPI_" << methodName
										 << " returned error code=" << errCode);
}
00400
00401 MPI_Op MPIComm::getOp(int op)
00402 {
00403
00404 TEST_FOR_EXCEPTION(
00405 !(op == SUM || op==MAX
00406 || op==MIN || op==PROD),
00407 std::range_error,
00408 "invalid operator "
00409 << op << " in MPIComm::getOp");
00410
00411 if( op == SUM) return MPI_SUM;
00412 else if( op == MAX) return MPI_MAX;
00413 else if( op == MIN) return MPI_MIN;
00414 return MPI_PROD;
00415 }
00416
00417 #endif