mem.h

/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include "base.h"

namespace marlin_24 {

// Predicated asynchronous global->shared copy; used for inputs A where we
// apply predication to handle batch sizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                            const void* glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
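
// Usage sketch (illustrative only; `sh_a`, `A`, `off`, and `valid` are
// hypothetical names, not part of this header): the copy is always issued,
// and rows flagged via `zfill` get a zero-filled shared-memory destination
// instead of data from global memory.
//
//   cp_async4_pred_zfill(&sh_a[off], &A[off], /*pred=*/true, /*zfill=*/!valid);

// Predicated asynchronous global->shared copy (no zero-fill option).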
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
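
// Sketch of the multi-stage prefetch pattern these primitives enable
// (illustrative only; `STAGES`, `sh_stage`, `gl_ptr`, and `idx` are
// hypothetical):
//
//   cp_async4(&sh_stage[next][idx], &gl_ptr[idx]);  // start fetching the next stage
//   cp_async_fence();                               // commit it as one copy group
//   cp_async_wait<STAGES - 2>();                    // allow at most STAGES-2 groups in flight
//   __syncthreads();                                // oldest stage is now safe to read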

// Instruction for loading a full 16x16 matrix fragment of operand A from
// shared memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}
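
// Instruction for loading a fragment of operand M (sparsity metadata) from
// shared memory via the two-matrix `ldmatrix.x2` form, which fills two 32-bit
// destination registers per thread.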
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from
// shared memory, directly in tensor core layout, using the transposed
// (`.trans`) ldmatrix variant.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}
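
// Usage sketch (illustrative only; `frag_a`, `sh_a`, and `a_sh_rd` are
// hypothetical): every thread of the warp supplies the shared-memory address
// of one 8x8 tile row, and the fragment's four registers come back already in
// tensor core order.
//
//   FragA frag_a;
//   ldsm4(frag_a, &sh_a[a_sh_rd]);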

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
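
// Usage sketch of the two barrier helpers (illustrative only; `locks`, `col`,
// `slice_idx`, and `last` are hypothetical): threadblocks that contribute to
// the same output tile serialize their global reduction through a counter.
//
//   barrier_acquire(&locks[col], slice_idx);  // wait for earlier contributors
//   // ... accumulate partial results in global memory ...
//   barrier_release(&locks[col], last);       // bump the count, or reset if last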

}  // namespace marlin_24