custom_all_reduce.cuh

#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)
namespace aphrodite {

constexpr int kMaxBlocks = 36;

// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;

struct Signal {
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Two sets of peer counters are needed for the two syncs: a peer GPU block
  // may arrive at the second sync point while the current GPU block hasn't
  // passed the first sync point yet. The peer could then write counter + 1
  // while the current GPU is still busy-waiting for counter, so we alternate
  // between two counter arrays to avoid this.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };

struct __align__(16) RankSignals { Signal* signals[8]; };

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
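
// Illustrative check (not part of the original interface): for T = half, the
// packed type P is array_t<half, 8> (16 bytes, i.e. one 128-bit load/store)
// and the accumulator A is array_t<float, 8>.
static_assert(sizeof(packed_t<half>::P) == 16 && packed_t<half>::P::size == 8,
              "expected 16-byte packed vectors for half");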
#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason, when compiling with PyTorch, the + operator for half and
// bfloat is disabled, so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
}

static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, release-acquire
// semantics are used to enforce memory access ordering before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
                               int rank) {
  if constexpr (!is_start) __syncthreads();
  static_assert(
      !(is_start && need_fence));  // Start barrier shouldn't need fence.
  if (threadIdx.x < ngpus) {
    // Increment the counter. Technically we only need one counter, but we use
    // multiple per block to eliminate the need to share the counter via smem.
    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
    // Write the expected counter value to the peer and wait for the correct
    // value from the peer.
    auto peer_counter_ptr =
        &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
    auto self_counter_ptr =
        &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
    if constexpr (need_fence) {
      st_flag_release(peer_counter_ptr, val);
      while (ld_flag_acquire(self_counter_ptr) != val);
    } else {
      st_flag_volatile(peer_counter_ptr, val);
      while (ld_flag_volatile(self_counter_ptr) != val);
    }
  }
  if constexpr (is_start || need_fence) __syncthreads();
}
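
// Example trace of the handshake above (illustrative, not from the original
// source): with ngpus = 2, on rank 0 in block b, thread 1 increments
// self_counter[b][1] to val, stores val into rank 1's
// peer_counter[val % 2][b][0], then spins on rank 0's own
// peer_counter[val % 2][b][1] until rank 1 (via its thread 0) stores val
// there. Alternating on val % 2 keeps a fast peer's next-barrier write from
// being mistaken for the current barrier's value.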
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}
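
// Worked example for the partitioning in the two-stage kernel below
// (illustrative numbers, not from the original source): with size = 1030
// packed elements and ngpus = 4, part = 257 and largest_part = 257 + 2 = 259.
// Ranks 0-2 reduce-scatter 257 elements each into their temporary buffers,
// while rank 3 takes the tail and reduces 259. In the allgather stage every
// rank loops over largest_part indices and skips the tail entries that only
// exist on the last rank.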
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
  // stage 2: allgather. Note: it's important to match the tid between the two
  // stages, because visibility across devices is only guaranteed between
  // threads that have the same tid. If thread i computes the sum of start + i
  // in the first stage, then thread i also gathers start + i from all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;
  /**
   * meta is a pointer to device metadata and the temporary buffer for
   * allreduce.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to the actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in via the constructor.
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get the
      // wrong address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }
  // note: when registering graph buffers, we intentionally choose not to
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // while rank 2 gets a different address. IPC handles have an internal
  // reference counting mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }
  /**
   * The default block_limit of 36 is the result of a careful grid search.
   * Using 36 blocks gives the best or close to the best runtime on the
   * devices I tried: A100, A10, A30, T4, V100. You'll notice that NCCL
   * kernels also use only a small number of SMs. I'm not quite sure of the
   * underlying reason, but my guess is that too many SMs will cause
   * contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
    // TODO(hanzhi713): Threshold is different for A100 and H100.
    // Add per device threshold.
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }
  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
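
/**
 * Host-side usage sketch (illustrative only; the variable names below are
 * placeholders, not part of this file). It assumes the caller has already
 * allocated `meta` (sizeof(Signal) plus a temporary buffer) and `rank_data`
 * on the device, and has exchanged the IPC handles and offsets with the other
 * ranks out of band:
 *
 *   aphrodite::CustomAllreduce ar(meta, rank_data, rank_data_sz,
 *                                 signal_handles, signal_offsets, rank);
 *   ar.register_buffer(buffer_handles, buffer_offsets, input);
 *   ar.allreduce<half>(stream, input, output, num_elements);
 */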
/**
 * To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and
 * add a template instantiation:
 *   template void aphrodite::CustomAllreduce::allreduce<half>(cudaStream_t,
 *       half*, half*, int, int, int);
 */
}  // namespace aphrodite