custom_all_reduce.cuh

#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)
namespace aphrodite {

constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal {
  alignas(128) uint32_t start[kMaxBlocks][8];
  alignas(128) uint32_t end[kMaxBlocks][8];
};

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct __align__(16) RankSignals { volatile Signal *signals[8]; };
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
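// Sanity checks illustrating the packing (added for illustration): a 16-bit
// element type packs 8 elements into exactly 16 bytes, which is what allows
// the compiler to emit 128-bit (ld.128/st.128) loads and stores.
static_assert(sizeof(packed_t<half>::P) == 16, "packed type must span 16 bytes");
static_assert(packed_t<half>::P::size == 8, "8 half elements per packed access");
static_assert(packed_t<float>::P::size == 4, "4 float elements per packed access");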
#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason, when compiling with PyTorch the + operator for half and
// bfloat16 is disabled, so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
                        int rank) {
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->end[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we get true from all ranks
    while (!self_sg->start[blockIdx.x][threadIdx.x])
      ;
  }
  __syncthreads();
}
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
                      int rank) {
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
  // visible. Note that I did not manage to make this happen through a lot of
  // testing. It might be that the hardware provides a stronger guarantee than
  // the memory model.
  if constexpr (!final_sync) __threadfence_system();
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->start[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we get true from all ranks
    while (!self_sg->end[blockIdx.x][threadIdx.x])
      ;
  }
  if constexpr (!final_sync) __syncthreads();
}
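// Note on the two barriers above: the start and end flag arrays are used
// alternately, and each barrier resets the *other* set of flags (start_sync
// clears `end`, end_sync clears `start`), so the same Signal struct can be
// reused across consecutive allreduce calls without extra synchronization.
// Each of the first ngpus threads in a block owns one peer: it sets its own
// rank's slot on that peer and spins on the slot the peer writes locally, so
// a full barrier costs a single p2p write per peer.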
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Signal *self_sg, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Signal *)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Signal *self_sg, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, self_sg, rank);
  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P *)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
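// The two-stage kernel above is a reduce-scatter followed by an all-gather.
// For example, with ngpus = 4 and size = 1024 packed elements, part = 256 and
// rank r reduces indices [256*r, 256*(r+1)) into its own temporary buffer
// (the last rank also takes any remainder); every rank then copies each
// peer's reduced slice out of that peer's temporary buffer into the matching
// offset of its local result. Each element crosses the interconnect roughly
// twice in total instead of once per peer, which is why this path wins for
// larger messages.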
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Signal *self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;
  /**
   * meta is a pointer to device metadata and the temporary buffer for
   * allreduce.
   *
   * There's a prefix of sizeof(Signal) bytes before the actual data,
   * so meta + 1 points to the actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Signal *rank_sg;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal *)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }
  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }
  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get the
      // wrong address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }
  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }
  // note: when registering graph buffers, we intentionally choose not to
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // while rank 2 gets a different address. IPC handles have an internal
  // reference-counting mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }
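  // Summary of the CUDA graph path: during capture, allreduce() cannot look
  // up or register input pointers, so it reserves the next RankData slot,
  // records the pointer in graph_unreg_buffers_, and launches the kernel
  // against that reserved slot. After capture, ranks exchange the IPC
  // handles/offsets produced by get_graph_buffer_ipc_meta(), and
  // register_graph_buffers() fills in the reserved slots so that graph replay
  // dereferences valid peer pointers.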
  /**
   * This is the result of a careful grid search. Using 36 blocks gives the
   * best or close to the best runtime on the devices I tried: A100, A10, A30,
   * T4, V100. You'll notice that NCCL kernels also only use a small number of
   * SMs. I'm not quite sure of the underlying reason, but my guess is that
   * too many SMs will cause contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }
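  // Kernel selection above: two ranks always use the one-stage kernel, since
  // reduce-scatter + all-gather offers no bandwidth advantage there. With
  // full NVLink, small messages (< 512 KiB for up to 4 GPUs, < 256 KiB for up
  // to 8 GPUs) also take the latency-friendly one-stage path, while larger
  // messages use the two-stage kernel for better bandwidth. Without full
  // NVLink and more than two ranks, no kernel is launched here and the caller
  // is expected to fall back to another allreduce implementation.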
  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and
 * add a template instantiation:
 * template void aphrodite::CustomAllreduce::allreduce<half>(cudaStream_t,
 *     half *, half *, int, int, int);
 */
}  // namespace aphrodite
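/*
 * Host-side usage sketch (added for illustration; not part of the original
 * header). It assumes one process per GPU and that the IPC handles and
 * offsets for the peer signal and data buffers have already been exchanged
 * out of band (e.g. via torch.distributed all_gather); all variable names
 * below are hypothetical.
 *
 *   // `meta` is a cudaMalloc'd region of sizeof(aphrodite::Signal) plus the
 *   // temporary buffer used by the two-stage kernel; `rank_data` is a small
 *   // cudaMalloc'd pool that holds RankData entries.
 *   aphrodite::CustomAllreduce ar(meta, rank_data, rank_data_sz,
 *                                 peer_signal_handles, signal_offsets, rank);
 *   // Register an input buffer once, using the IPC handles gathered from
 *   // all ranks for that buffer, then reduce into `output` on `stream`.
 *   ar.register_buffer(buffer_handles, buffer_offsets, input);
 *   ar.allreduce<half>(stream, input, output, num_elements);
 */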