
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
/***************************************************************************
 * Copyright 2023 The Flash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/
#ifndef PTX_MMA_CUH
#define PTX_MMA_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <assert.h>

#include "configs.h"
// MODIFICATION NOTE: to support MSVC
// - uint32_t __restrict__ Reg[][4] is changed to
//   uint32_t (*__restrict__ Reg)[4]
// - half __restrict__ (*read_SPTR) is changed to
//   half (*__restrict__ read_SPTR)
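//
// B_FromSharedToReg loads one 16-deep (K) slice of the B tile from shared
// memory into per-warp register fragments with ldmatrix, producing b operands
// for the m16n8k16 tensor core MMA below. Each ldmatrix.x4 issue fills
// Reg[i][0..3]: four 8x8 b16 matrices covering two adjacent n=8 MMA tensors.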
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(
    uint32_t (*__restrict__ Reg)[4],
    half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
    int slice_id) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#ifdef DEBUG_MODE
  static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
                (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));
#endif
  const int warpId = threadIdx.x / WARP_SIZE;
  int lane_id = threadIdx.x % WARP_SIZE;
  int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
  // Each warp may start reading from the warp_start_col'th column of the B
  // tile in shared memory.
  int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j;
#ifdef DEBUG_MODE
  assert(warp_start_col == 0);
#endif
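  // With ldmatrix, each lane supplies the shared-memory address of one
  // 16-byte row (8 halves) of an 8x8 b16 matrix. The col/row arithmetic
  // below spreads lanes 0-31 over a 16(N) x 16(K) block of the B tile:
  // lanes 0-15 supply columns 0-7 and lanes 16-31 columns 8-15, each group
  // split between K offsets 0 and 8.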
  int col = (lane_id % 8) + (lane_id / 16) * 8;
  int row = (lane_id % 16) / 8 * 8;
  uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(
      &read_SPTR[warp_start_col + col][slice_id * MMA_16 + row]));
  if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
    asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
                 : "=r"(Reg[0][0]), "=r"(Reg[0][1])
                 : "r"(smem_local_ptr));
  } else {
#pragma unroll
    for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) {
      asm volatile(
          "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
          : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
          : "r"(smem_local_ptr));
      // Advance 16 columns (two n=8 tensors) to the next x4 load.
      smem_local_ptr +=
          16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
    }
  }
#endif
}
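
// Note: each Reg[i] filled above holds the B fragments for two adjacent n=8
// MMA tensors: Reg[i] + 0 is the b operand for the even tensor and
// Reg[i] + 2 for the odd one.
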
// MODIFICATION NOTE: to support MSVC, the function signature is changed from
// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a,
// uint32_t __restrict__ *b).
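//
// Per-thread fragment layout for mma.m16n8k16 (f16 inputs, f32 accumulate):
// a[0..3] hold 8 halves of the 16x16 A tile, b[0..1] hold 4 halves of the
// 16x8 B tile, and c[0..3] are 4 f32 accumulators passed as raw bit patterns
// in uint32_t registers (hence the "r" constraints below).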
__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c,
                                                  uint32_t* __restrict__ a,
                                                  uint32_t* __restrict__ b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{ %0, %1, %2, %3 },"
      "{ %4, %5, %6, %7 },"
      "{ %8, %9 },"
      "{ %10, %11, %12, %13 };"
      : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}
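
// Usage sketch (illustrative only; the helper name and fragment sources are
// our assumptions, not part of the upstream kernels): one tensor core step
// combining the two functions above. b_frags is the Reg array filled by
// B_FromSharedToReg, a_frag holds the matching A fragments, j indexes the
// n=8 tensor within the warp tile, and acc points at four f32 accumulators
// reused as both the C input and D output.
__device__ __forceinline__ void MMA_Step_Sketch(float* __restrict__ acc,
                                                uint32_t* __restrict__ a_frag,
                                                uint32_t (*__restrict__ b_frags)[4],
                                                int j) {
  // Select two of the four registers in b_frags[j / 2] as tensor j's b
  // operand, per the fragment-ordering note above.
  MMA_FP16_M16N8K16(reinterpret_cast<uint32_t*>(acc), a_frag,
                    b_frags[j / 2] + (j % 2) * 2);
}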

#endif  // PTX_MMA_CUH