mma.h

/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "base.h"

#include <cudaTypedefs.h>

namespace marlin_24 {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);

  float* c = reinterpret_cast<float*>(&frag_c);
  if (psel == 0) {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}
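
// Reading aid for the operand mapping above (an added note, inferred from the
// PTX mma.sp shape, not an authoritative description of the caller's tiling):
// the 2:4-sparse 16x32 fp16 operand is supplied compressed through
// a_frag0/a_frag1 (four .b32 registers, interleaved as a0[0], a1[0], a0[1],
// a1[1]), the dense 32x8 fp16 operand comes from the registers read through
// frag_b, and frag_m carries the sparsity metadata. Each mma.sp produces one
// 16x8 fp32 tile: the first call uses the even registers b[0,2,4,6] and
// accumulates into c[0..3], the second uses the odd registers b[1,3,5,7] and
// accumulates into c[4..7], so one mma_sp call updates two adjacent 16x8
// accumulator tiles. `psel` is the PTX sparsity selector (0x0 or 0x1)
// choosing which threads contribute the metadata.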

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it
// in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}
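
// The immediate `lut` follows the standard LOP3 construction: apply the
// desired boolean expression to the constants 0xF0 (a), 0xCC (b) and 0xAA (c).
// For example, the LUT used by the dequantizers below,
// (0xf0 & 0xcc) | 0xaa == 0xEA, makes lop3 compute (a & b) | c in a single
// instruction, e.g. (illustrative call only):
//   int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, 0x000f000f, 0x64006400);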

__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  uint2 r;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(r.x), "=r"(r.y)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return r;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}
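
// prmt.b32 treats the pair {start_byte, a} as an 8-byte vector and picks one
// source byte per result byte according to the selector nibbles of `mask`
// (values 0-3 index bytes of a, 4-7 index bytes of start_byte). As used in
// dequant_8bit below, start_byte = 0x64646464 and mask = 0x5250 build the
// result {0x64, a.byte2, 0x64, a.byte0}, i.e. two fp16 bit patterns of the
// form 0x64XX in a single instruction.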

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}
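
// Worked example of the magic-number trick: for a low nibble v,
// (q & 0x000f) | 0x6400 is the fp16 bit pattern of 1024 + v, so subtracting
// SUB = 0x6408 (fp16 1032.0) yields v - 8. For a high nibble v,
// (q & 0x00f0) | 0x6400 encodes 1024 + 16*v, and the fused multiply-add with
// MUL = 0x2c00 (fp16 0.0625) and ADD = 0xd480 (fp16 -72.0) gives
// (1024 + 16*v) / 16 - 72 = v - 8 as well.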

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}
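
// Same trick as in the 4-bit path: each selected byte b lands in the low byte
// of an fp16 with upper byte 0x64, i.e. the bit pattern of 1024 + b, and
// subtracting I8s_TO_F16s_MAGIC_NUM = 0x6480 (fp16 1152.0) yields b - 128,
// the analogous symmetric zero-point shift for the 8-bit case.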

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                    FragS& s0, float* c4, float* c5, float* c6,
                                    float* c7, FragS& s1) {
  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));

  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

}  // namespace marlin_24