/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed on your system.
 * export MPI_HOME=XXX
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <iostream>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"
#define MPICHECK(cmd) \
  do { \
    int e = cmd; \
    if (e != MPI_SUCCESS) { \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)

#define NCCLCHECK(cmd) \
  do { \
    ncclResult_t r = cmd; \
    if (r != ncclSuccess) { \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r)); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)
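
// Device-side sleep of roughly 100 ms, launched once before each timed section
// below (__nanosleep requires sm_70 or newer).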
__global__ void dummy_kernel() {
  for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
}
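
// Fills a buffer with a constant derived from the rank; only referenced by the
// commented-out debugging launch inside run().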
template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
  }
}
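
// Widens two device buffers of T into double so the host-side verification can
// compare them at full precision.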
template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
                             double *fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
    fdata2[idx] = data2[idx];
  }
}
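
// Seeds one curand state per (element, rank) pair. The seed depends only on
// the rank index, so every rank can regenerate the values of all other ranks.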
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
      curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
    }
  }
}
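
// Draws values for all ranks from the shared random states, stores only this
// rank's value (downcast to T) in data, and accumulates the per-element sum in
// double precision as the ground truth.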
template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    double sum = 0.0;
    for (int i = 0; i < nRanks; i++) {
      double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
      T hval = val;  // downcast first
      sum += static_cast<double>(hval);
      if (i == myRank) data[idx] = hval;
    }
    ground_truth[idx] = sum;
  }
}
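
// Runs one configuration: allocates and IPC-registers the buffers, generates
// random inputs plus a double-precision ground truth, times NCCL and the
// custom allreduce on the same data, and verifies that the results agree.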
template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
         int data_size) {
  T *result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
  CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  aphrodite::Metadata *buffer;
  T *self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(cudaMalloc(
      &buffer, 2 * data_size * sizeof(T) + sizeof(aphrodite::Metadata)));
  CUDACHECK(cudaMemset(buffer, 0,
                       2 * data_size * sizeof(T) + sizeof(aphrodite::Metadata)));
  CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void *rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  aphrodite::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                                offsets, myRank);
  auto *self_data = reinterpret_cast<T *>(
      reinterpret_cast<char *>(buffer) + sizeof(aphrodite::Metadata) +
      data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char *begin = (char *)&data_handles[i];
      char *end = (char *)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(
        nRanks, sizeof(aphrodite::Metadata) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
  }
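
  // Generate the random input and ground truth, and keep a copy of the input
  // so the same data can be fed to NCCL later for verification.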
  double *ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t *states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }
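
  // Benchmark NCCL first: in-place allreduce on `result`, warmed up and then
  // averaged over num_iters iterations using CUDA events.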
  dummy_kernel<<<1, 1, 0, stream>>>();
  constexpr int warmup_iters = 5;
  constexpr int num_iters = 25;
  // warmup
  for (int i = 0; i < warmup_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float allreduce_ms = 0;
  cudaEventElapsedTime(&allreduce_ms, start, stop);
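
  // Benchmark the custom allreduce with the same warmup and timing scheme,
  // reading from the registered self_data buffer and writing into result.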
  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);

  dummy_kernel<<<1, 1, 0, stream>>>();
  // warm up
  for (int i = 0; i < warmup_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float duration_ms = 0;
  cudaEventElapsedTime(&duration_ms, start, stop);

  if (myRank == 0)
    printf(
        "Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl "
        "time:%.2fus\n",
        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

  // And wait for all the queued up work to complete
  CUDACHECK(cudaStreamSynchronize(stream));
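
  // Reference result: run NCCL on the preserved copy of the input, widen both
  // results to double on the device, then compare element-wise on the host.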
  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                          ncclSum, comm, stream));

  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                            my_result, data_size);
  CUDACHECK(cudaStreamSynchronize(stream));

  for (unsigned long j = 0; j < data_size; j++) {
    auto diff = abs(nccl_result[j] - my_result[j]);
    if (diff >= 1e-2) {
      printf("Rank %d: Verification mismatch at %lu: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      break;
    }
  }
  long double nccl_diffs = 0.0;
  long double my_diffs = 0.0;
  for (int j = 0; j < data_size; j++) {
    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    my_diffs += abs(my_result[j] - ground_truth[j]);
  }
  if (myRank == 0)
    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
              << " me: " << my_diffs / data_size << std::endl;

  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}
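
// One MPI rank drives one GPU: rank i binds to CUDA device i, NCCL is
// bootstrapped by broadcasting its unique id over MPI, and a doubling sweep of
// element counts (512 up to 32M, each offset by 400) is benchmarked in fp16.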
int main(int argc, char **argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));

  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) NCCLCHECK(ncclGetUniqueId(&id));
  MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  cudaProfilerStart();
  // for (int threads : {256, 512}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
  }
  cudaProfilerStop();
  MPICHECK(MPI_Finalize());
  return EXIT_SUCCESS;
}