1
0

ln_kernel_traits.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #pragma once
  2. ////////////////////////////////////////////////////////////////////////////////////////////////////
  3. namespace layer_norm {
  /// Common compile-time parameters shared by the layer-norm kernel-trait structs.
  ///
  /// Re-exports the element types under un-suffixed names and exposes the hidden
  /// size, the number of threads per CTA and the (fixed) warp width as integral
  /// constants.
  ///
  /// @tparam HIDDEN_SIZE_     Number of elements in one hidden vector.
  /// @tparam weight_t_        Weight element type.
  /// @tparam input_t_         Input element type.
  /// @tparam residual_t_      Residual element type.
  /// @tparam output_t_        Output element type.
  /// @tparam compute_t_       Type used for internal arithmetic.
  /// @tparam index_t_         Index type.
  /// @tparam THREADS_PER_CTA_ Threads launched per CTA.
  template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename residual_t_,
            typename output_t_, typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_>
  struct Kernel_traits_base {
      // Integral constants. Each one deliberately lives in its own unnamed enum, so it
      // is a constant expression that needs no out-of-class definition and keeps its
      // own distinct enumeration type.
      enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
      enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
      enum { THREADS_PER_WARP = 32 };  // warp width assumed by every derived trait

      // Element types, re-exported for the derived traits and the kernels.
      using weight_t   = weight_t_;
      using input_t    = input_t_;
      using residual_t = residual_t_;
      using output_t   = output_t_;
      using compute_t  = compute_t_;
      using index_t    = index_t_;
  };
  25. ////////////////////////////////////////////////////////////////////////////////////////////////////
  26. template<
  27. uint32_t HIDDEN_SIZE_,
  28. typename weight_t_,
  29. typename input_t_,
  30. typename residual_t_,
  31. typename output_t_,
  32. typename compute_t_,
  33. typename index_t_,
  34. bool Has_colscale,
  35. uint32_t THREADS_PER_CTA_,
  36. uint32_t BYTES_PER_LDG_,
  37. typename Base = Kernel_traits_base<HIDDEN_SIZE_,
  38. weight_t_,
  39. input_t_,
  40. residual_t_,
  41. output_t_,
  42. compute_t_,
  43. index_t_,
  44. THREADS_PER_CTA_>
  45. >
  46. struct Kernel_traits_finalize : public Base {
  47. enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
  48. static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
  49. // Bytes per global load from the input.
  50. enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
  51. // Number of elements fetched by a global load.
  52. enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
  53. // Bytes per global store of the weights.
  54. enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
  55. static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
  56. static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
  57. // The total number of BYTES_PER_LDG-wide words in a hidden vector.
  58. enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
  59. static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
  60. // Shared memory size to transpose the CTA result.
  61. enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
  62. // Shared memory size to coalsece the CTA result.
  63. enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
  64. // Shared memory requirement per CTA.
  65. static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
  66. enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
  67. // The type of the reducer.
  68. using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
  69. // Condition for the whole CTA to participate in syncthreads.
  70. static_assert(COLS % Base::THREADS_PER_WARP == 0);
  71. enum { CTAS = COLS / Base::THREADS_PER_WARP };
  72. };
  73. ////////////////////////////////////////////////////////////////////////////////////////////////////
  74. template<
  75. typename weight_t_,
  76. typename input_t_,
  77. typename residual_t_,
  78. typename output_t_,
  79. typename compute_t_,
  80. typename index_t_,
  81. uint32_t HIDDEN_SIZE_,
  82. uint32_t CTAS_PER_ROW_,
  83. uint32_t WARPS_M_,
  84. uint32_t WARPS_N_,
  85. uint32_t BYTES_PER_LDG_ = 16,
  86. typename Base = Kernel_traits_base<
  87. HIDDEN_SIZE_,
  88. weight_t_,
  89. input_t_,
  90. residual_t_,
  91. output_t_,
  92. compute_t_,
  93. index_t_,
  94. WARPS_M_*WARPS_N_*THREADS_PER_WARP
  95. >
  96. >
  97. struct Kernel_traits : public Base {
  98. using input_t = typename Base::input_t;
  99. using residual_t = typename Base::residual_t;
  100. using weight_t = typename Base::weight_t;
  101. using compute_t = typename Base::compute_t;
  102. using output_t = typename Base::output_t;
  103. using index_t = typename Base::index_t;
  104. // using mask_t = unsigned char;
  105. using mask_t = bool;
  106. enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
  107. enum { WARPS_M = WARPS_M_ };
  108. enum { WARPS_N = WARPS_N_ };
  109. enum { COLS = HIDDEN_SIZE_ };
  110. enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
  111. enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
  112. enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
  113. enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
  114. enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
  115. enum { ROWS_PER_CTA = WARPS_M };
  116. enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
  117. enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
  118. // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
  119. enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
  120. static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
  121. using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
  122. using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
  123. enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
  124. enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
  125. using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
  126. using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
  127. using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
  128. using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
  129. using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
  130. using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
  131. enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
  132. // Assume that each thread can handle the same number of elements in the output and weights as in the input.
  133. static_assert(sizeof(input_t) == sizeof(output_t));
  134. static_assert(sizeof(input_t) <= sizeof(residual_t));
  135. // The number of columns fetched per load from input: one per thread.
  136. enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
  137. // The total number of vectorized loads/stores per hidden vector.
  138. enum { VEC_COLS = COLS / ELTS_PER_LDG };
  139. // The number of loads per thread for the input.
  140. enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
  141. static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
  142. //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
  143. using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
  144. enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
  145. };
  146. ////////////////////////////////////////////////////////////////////////////////////////////////////
  147. } // namespace layer_norm