Sfoglia il codice sorgente

fix: no repeated IPC registration (#227)

AlpinDale 1 anno fa
parent
commit
6305e6f3f2
1 ha cambiato i file con 25 aggiunte e 18 eliminazioni
  1. 25 18
      kernels/all_reduce/custom_all_reduce.cuh

+ 25 - 18
kernels/all_reduce/custom_all_reduce.cuh

@@ -7,6 +7,7 @@
 
 #include <iostream>
 #include <limits>
+#include <map>
 #include <unordered_map>
 #include <vector>
 
@@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
   }
 }
 
+using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
+static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
+static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
+
 class CustomAllreduce {
  public:
   int rank_;
@@ -341,7 +346,8 @@ class CustomAllreduce {
   // stores the registered device pointers from all ranks
   RankData *d_rank_data_base_, *d_rank_data_end_;
   std::vector<void *> graph_unreg_buffers_;
-  std::vector<void *> ipc_handles_;
+  // a map from IPC handles to opened IPC pointers
+  std::map<IPC_KEY, char *> ipc_handles_;
 
   /**
    * meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -365,10 +371,7 @@ class CustomAllreduce {
     for (int i = 0; i < world_size_; i++) {
       Metadata *rank_meta;
       if (i != rank_) {
-        char *handle;
-        CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
-                                       cudaIpcMemLazyEnablePeerAccess));
-        ipc_handles_.push_back(handle);
+        char *handle = open_ipc_handle(&handles[i]);
         handle += offsets[i];
         rank_meta = (Metadata *)handle;
       } else {
@@ -378,6 +381,19 @@ class CustomAllreduce {
     }
   }
 
+  char *open_ipc_handle(const void *ipc_handle) {
+    auto [it, new_handle] =
+        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
+    if (new_handle) {
+      char *ipc_ptr;
+      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
+                                     *((const cudaIpcMemHandle_t *)ipc_handle),
+                                     cudaIpcMemLazyEnablePeerAccess));
+      it->second = ipc_ptr;
+    }
+    return it->second;
+  }
+
   std::pair<std::vector<uint8_t>, std::vector<int64_t>>
   get_graph_buffer_ipc_meta() {
     auto num_buffers = graph_unreg_buffers_.size();
@@ -413,11 +429,7 @@ class CustomAllreduce {
     RankData data;
     for (int i = 0; i < world_size_; i++) {
       if (i != rank_) {
-        char *handle;
-        CUDACHECK(cudaIpcOpenMemHandle(
-            (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
-            cudaIpcMemLazyEnablePeerAccess));
-        ipc_handles_.push_back(handle);
+        char *handle = open_ipc_handle(handles[i].data());
         handle += offsets[i];
         data.ptrs[i] = handle;
       } else {
@@ -448,13 +460,8 @@ class CustomAllreduce {
       auto &rd = rank_data[i];
       for (int j = 0; j < world_size_; j++) {
         if (j != rank_) {
-          char *handle;
-          CUDACHECK(cudaIpcOpenMemHandle(
-              (void **)&handle,
-              *((cudaIpcMemHandle_t *)&handles[j]
-                                              [i * sizeof(cudaIpcMemHandle_t)]),
-              cudaIpcMemLazyEnablePeerAccess));
-          ipc_handles_.push_back(handle);
+          char *handle =
+              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
           handle += offsets[j][i];
           rd.ptrs[j] = handle;
         } else {
@@ -541,7 +548,7 @@ class CustomAllreduce {
   }
 
   ~CustomAllreduce() {
-    for (auto ptr : ipc_handles_) {
+    for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
   }