custom_all_reduce.cuh

#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)
namespace aphrodite {

struct Signal {
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } start;
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } end;
};

struct Metadata {
  alignas(128) Signal sg;
  alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
  volatile Signal *signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
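
// For example, with half: P packs 8 elements into one 16-byte, 16-byte-aligned
// vector, and A accumulates the same 8 lanes in float.
static_assert(packed_t<half>::P::size == 8);
static_assert(sizeof(packed_t<half>::P) == 16);
static_assert(packed_t<half>::A::size == 8);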

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason, when compiling with PyTorch, the + operator for half and
// bfloat is disabled, so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
  auto m = std::numeric_limits<uint64_t>::max();
  return m >> ((8 - ngpus) * 8);
}
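
// Each participating rank owns one byte of the 64-bit flag, so the
// "all ranks arrived" value is ngpus bytes of 0xff, e.g.:
static_assert(compute_flag(2) == 0xffff);
static_assert(compute_flag(4) == 0xffffffff);
static_assert(compute_flag(8) == std::numeric_limits<uint64_t>::max());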

template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
                        int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  if (blockIdx.x == 0) {
    if (threadIdx.x < ngpus)
      // simultaneously write the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start.data[rank] = 255;
    else if (threadIdx.x == 32)
      // reset
      meta->sg.end.flag = 0;
  }
  if (threadIdx.x == 0) {
    while (meta->sg.start.flag != FLAG)
      ;
  }
  __syncthreads();
}
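
// end_sync is the release half of the barrier: the last block to arrive
// (tracked with meta->counter) resets the counter and start flag, then writes
// its byte into every peer's end flag. With final_sync, only that last block
// busy-waits, since kernel exit synchronizes the remaining blocks; otherwise
// every block waits until all ngpus end bytes are set.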
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
                      int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  __syncthreads();
  __shared__ int num;
  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
  __syncthreads();

  // Only the last completing block can perform the end synchronization.
  // This ensures that when the final busy wait ends, all ranks have
  // finished reading each other's buffer.
  if (num == gridDim.x - 1) {
    if (threadIdx.x == 32) {
      // reset in a different warp
      meta->counter = 0;
      meta->sg.start.flag = 0;
    } else if (threadIdx.x < ngpus) {
      // simultaneously write the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end.data[rank] = 255;
    }
    // if this is the final sync, only one block needs it
    // because kernel exit can serve as sync
    if constexpr (final_sync) {
      if (threadIdx.x == 0) {
        while (meta->sg.end.flag != FLAG)
          ;
      }
    }
  }
  if constexpr (!final_sync) {
    if (threadIdx.x == 0) {
      while (meta->sg.end.flag != FLAG)
        ;
    }
    __syncthreads();
  }
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, meta, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, meta, rank);
}
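
// Each rank's temporary buffer lives directly after its Metadata block in the
// same allocation (meta + 1 points to the scratch space; see the constructor
// doc below), so the scratch pointer can be recovered from the signal pointer.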
template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Metadata *)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, meta, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  // Maybe TODO: replace this with per-block release-acquire,
  // which can save about 1-2us (not a lot though)
  end_sync<ngpus>(sg, meta, rank);

  // stage 2: allgather
  for (int idx = tid; idx < part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int dst_idx = ((rank + i) % ngpus) * part + idx;
      ((P *)result)[dst_idx] = tmps[i][idx];
    }
  }
  // process the last, larger partition
  int remaining = size - part * ngpus;
  if (tid < remaining) {
    int dst_idx = tid + part * ngpus;
    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
  }

  // The split above is faster than this simpler loop:
  // for (int idx = tid; idx < size; idx += stride) {
  //   int target_rank = idx / part;
  //   if (target_rank == ngpus) target_rank -= 1;
  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
  // }
}
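
// Example of the partitioning above: with ngpus = 4 and size = 1022 packed
// elements, part = 255; ranks 0-2 each reduce 255 elements, rank 3 reduces
// 257, and the tail copy in the allgather handles the remaining 2 elements
// past part * ngpus.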

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
                                       volatile Metadata *meta,
                                       T *__restrict__ result, int rank,
                                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
  constexpr int hg = ngpus / 2;
  // Actually not quite half butterfly.
  // This is an all-to-all within each group containing half of the ranks,
  // followed by a cross-group add. Equivalent to half butterfly when there
  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
  const P *ptrs[hg];
  {
    int start = rank - rank % hg;
#pragma unroll
    for (int i = 0; i < hg; i++) {
      ptrs[i] = (const P *)_dp->ptrs[i + start];
    }
  }
  start_sync<ngpus>(sg, meta, rank);
  for (int idx = tid; idx < size; idx += stride) {
    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, meta, rank);

  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
  // do the cross-group reduction
  for (int idx = tid; idx < size; idx += stride) {
    auto tmp = tmp_out[idx];
    packed_assign_add(tmp, src[idx]);
    ((P *)result)[idx] = tmp;
  }
}
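
// Note on the pairing above: since rank < ngpus, (ngpus - 1) - rank % ngpus is
// simply ngpus - 1 - rank, so with 4 GPUs the cross-group partners are 0<->3
// and 1<->2.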

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Metadata *meta_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;

  /**
   * meta is a pointer to the device metadata and temporary buffer for
   * allreduce.
   *
   * There's a total of sizeof(Metadata) of prefix before the actual data,
   * so meta + 1 points to the actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in via the constructor.
   */
  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        meta_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Metadata *rank_meta;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_meta = (Metadata *)handle;
      } else {
        rank_meta = meta_;
      }
      sg_.signals[i] = &rank_meta->sg;
    }
  }
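
  // Each unique IPC handle is opened at most once: handles are keyed by their
  // raw bytes in ipc_handles_, so repeated registrations reuse the mapped
  // pointer and the destructor closes every mapping exactly once.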
  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get the
      // wrong address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }
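
  // Guards against appending more RankData entries than the caller-provided
  // rank_data region can hold; each registration advances d_rank_data_base_.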
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose not to
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have an internal
  // reference-counting mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result of a careful grid search. Using 36 blocks gives the
   * best or close to the best runtime on the devices I tried: A100, A10, A30,
   * T4, V100. You'll notice that NCCL kernels also use only a small number of
   * SMs. Not quite sure of the underlying reason, but my guess is that too
   * many SMs will cause contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be a multiple "
          "of " +
          std::to_string(d));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  name<T, ngpus>        \
      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    } else {                                          \
      KL(ngpus, cross_device_reduce_half_butterfly);  \
    }                                                 \
    break;                                            \
  }
    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
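
/**
 * Rough host-side usage sketch (illustrative only; not part of the library).
 * The handle/offset variables and sizes below are placeholders and must come
 * from an out-of-band exchange between ranks (e.g. via torch.distributed):
 *
 *   Metadata *meta;
 *   CUDACHECK(cudaMalloc(&meta, sizeof(Metadata) + tmp_buffer_bytes));
 *   CUDACHECK(cudaMemset(meta, 0, sizeof(Metadata)));
 *   void *rank_data;
 *   CUDACHECK(cudaMalloc(&rank_data, rank_data_bytes));
 *   CustomAllreduce ar(meta, rank_data, rank_data_bytes, peer_meta_handles,
 *                      meta_offsets, rank);
 *   ar.register_buffer(peer_input_handles, input_offsets, input_ptr);
 *   ar.allreduce<half>(stream, input_ptr, output_ptr, num_elements);
 */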

/**
 * To inspect PTX/SASS, copy-paste this header file to compiler explorer and
 * add a template instantiation:
 *   template void CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 *                                                  half *, int, int, int);
 */
}  // namespace aphrodite