mem.cuh

/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstdint>  // uint32_t

// Predicated asynchronous global->shared copy; used for inputs A where we
// apply predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

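// Illustrative sketch, not part of the original file: one predicated 16-byte
// load per thread for an A tile whose row count (the batch size) may not be a
// multiple of the tile height. The flat one-chunk-per-thread indexing and the
// parameter names are assumptions made for this example.
__device__ inline void example_load_a_pred(int4* sh_a, const int4* glob_a,
                                           int chunk_idx, int chunks_valid) {
  // Every thread makes the call so control flow stays uniform; threads whose
  // chunk lies past the valid range pass pred=false, so they issue no global
  // read and write nothing to shared memory.
  cp_async4_pred(&sh_a[chunk_idx], &glob_a[chunk_idx],
                 chunk_idx < chunks_valid);
}
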
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

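// Illustrative sketch, not part of the original file: a minimal multi-stage
// prefetch loop built from the four primitives above (Marlin-style kernels
// follow the same pattern but interleave the fetches with compute and use far
// more elaborate indexing). `sh_buf` must hold STAGES * tile_int4s int4
// elements; the tile layout and all names are assumptions for this example.
template <int STAGES>
__device__ inline void example_prefetch_pipeline(int4* sh_buf,
                                                 const int4* glob, int n_tiles,
                                                 int tile_int4s) {
  int fetch = 0;
  // Fill the pipeline: issue up to STAGES tiles, committing each one as its
  // own cp.async group.
  for (; fetch < STAGES && fetch < n_tiles; fetch++) {
    for (int i = threadIdx.x; i < tile_int4s; i += blockDim.x)
      cp_async4(&sh_buf[fetch * tile_int4s + i], &glob[fetch * tile_int4s + i]);
    cp_async_fence();
  }
  for (int consume = 0; consume < n_tiles; consume++) {
    if (fetch < n_tiles)
      // Steady state: let the newer prefetches stay in flight and only
      // require the oldest group (the tile consumed next) to have landed.
      cp_async_wait<STAGES - 1>();
    else
      // Tail: nothing left to prefetch, so simply drain all pending groups.
      cp_async_wait<0>();
    __syncthreads();
    // ... compute on sh_buf + (consume % STAGES) * tile_int4s here ...
    __syncthreads();  // everyone is done with this slot before it is reused
    if (fetch < n_tiles) {
      for (int i = threadIdx.x; i < tile_int4s; i += blockDim.x)
        cp_async4(&sh_buf[(fetch % STAGES) * tile_int4s + i],
                  &glob[fetch * tile_int4s + i]);
      cp_async_fence();
      fetch++;
    }
  }
}
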
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Acquire load: once the expected count is observed, all writes made
      // globally visible before the matching release are guaranteed to be
      // visible to this threadblock's subsequent reads.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
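
// Illustrative sketch, not part of the original file: how the two barrier
// helpers can serialize a global reduction across the threadblocks that share
// one output tile. Block `slice_idx` out of `slice_count` waits for its
// predecessors, accumulates its partial result into global memory, then hands
// the lock on; the last block resets the counter so it can be reused. All
// names and the float accumulator layout are assumptions for this example.
__device__ inline void example_serial_reduce(int* lock, float* glob_acc,
                                             const float* partial, int n,
                                             int slice_idx, int slice_count) {
  // Proceed only once `slice_idx` predecessor blocks have released the lock;
  // their contributions to glob_acc are then guaranteed to be visible.
  barrier_acquire(lock, slice_idx);
  // Accumulate this block's partial results (block-strided over the tile).
  for (int i = threadIdx.x; i < n; i += blockDim.x) glob_acc[i] += partial[i];
  // Hand the lock to the next block, or reset it if this is the last slice.
  barrier_release(lock, slice_idx == slice_count - 1);
}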