utils.cpp

#include <numa.h>
#include <omp.h>
#include <sched.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "cpu_types.hpp"
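
// Binds the calling process's OpenMP threads one-to-one onto the CPU cores
// named by `cpu_ids` (a libnuma cpustring such as "0-3,8"), restricts memory
// allocation to the matching NUMA node, and returns a human-readable report
// of the resulting thread-to-core mapping.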
std::string init_cpu_threads_env(const std::string& cpu_ids) {
  bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
  // numa_parse_cpustring returns NULL on a malformed cpustring.
  TORCH_CHECK(omp_cpu_mask != nullptr && omp_cpu_mask->size > 0);
  std::vector<int> omp_cpu_ids;
  omp_cpu_ids.reserve(omp_cpu_mask->size);
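
  // Expand the bitmask into an explicit list of logical CPU ids, scanning
  // one unsigned long worth of bits (group_size) at a time.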
  constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
  for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
    unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
    int i = 0;
    while (group_mask) {
      if (group_mask & 1) {
        omp_cpu_ids.emplace_back(offset + i);
      }
      ++i;
      group_mask >>= 1;
    }
  }

  // Memory node binding
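  // Pages the process has already touched are migrated to the NUMA node of
  // the first requested core, and future allocations are restricted to that
  // node, so the OMP threads below work on node-local memory.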
  if (numa_available() != -1) {
    int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
    bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
    bitmask* src_mask = numa_get_membind();

    int pid = getpid();

    // Move all existing pages to the specified NUMA node. The XOR clears the
    // target node's bit from the current membind mask (which by default
    // covers every node), leaving the remaining nodes as migration sources.
    *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
    int page_num = numa_migrate_pages(pid, src_mask, mask);
    if (page_num == -1) {
      TORCH_CHECK(false,
                  "numa_migrate_pages failed. errno: " + std::to_string(errno));
    }

    // Restrict future memory allocation to the target node.
    numa_set_membind(mask);
    numa_set_strict(1);

    // numa_set_membind copies the mask, so the temporaries can be freed.
    numa_free_nodemask(mask);
    numa_free_nodemask(src_mask);
  }

  // OMP threads binding
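  // Run exactly one OMP thread per allowed core, and keep torch's intra-op
  // thread pool in sync with the OMP runtime.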
  omp_set_num_threads((int)omp_cpu_ids.size());
  torch::set_num_threads((int)omp_cpu_ids.size());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());

  std::vector<std::pair<int, int>> thread_core_mapping;
  thread_core_mapping.reserve(omp_cpu_ids.size());
  omp_lock_t writelock;
  omp_init_lock(&writelock);
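
  // schedule(static, 1) hands each thread exactly one iteration, i.e. one
  // core, so the affinity set inside the loop pins that thread permanently.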
#pragma omp parallel for schedule(static, 1)
  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(omp_cpu_ids[i], &mask);
    int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
    if (ret == -1) {
      TORCH_CHECK(false,
                  "sched_setaffinity failed. errno: " + std::to_string(errno));
    }

    omp_set_lock(&writelock);
    thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
    omp_unset_lock(&writelock);
  }
  omp_destroy_lock(&writelock);

  numa_free_nodemask(omp_cpu_mask);

  std::stringstream ss;
  ss << "OMP threads binding of Process " << getpid() << ":\n";
  std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
            [](auto&& a, auto&& b) { return a.second < b.second; });
  for (auto&& item : thread_core_mapping) {
    ss << "\t"
       << "OMP tid: " << item.first << ", core " << item.second << "\n";
  }
  return ss.str();
}
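
// Usage sketch (hypothetical standalone caller; in the real project this
// helper is presumably invoked through the torch extension bindings rather
// than called directly):
//
//   #include <cstdio>
//
//   int main() {
//     // Bind this process's OMP threads to cores 0-3 and report the result.
//     std::string report = init_cpu_threads_env("0-3");
//     std::printf("%s", report.c_str());
//     return 0;
//   }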