ln.h 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. #pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
  10. namespace layer_norm {
  11. ////////////////////////////////////////////////////////////////////////////////////////////////////
  12. template<typename Params>
  13. struct LaunchParams{
  14. size_t elts_per_thread;
  15. size_t workspace_bytes;
  16. size_t barrier_size;
  17. cudaDeviceProp * props;
  18. cudaStream_t stream;
  19. Params params;
  20. };
  21. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters shared by the forward and backward parameter structs.
  //
  // Every member carries a default initializer, so a default-constructed
  // ParamsBase has no indeterminate fields: pointers are null, sizes are 0,
  // and the scale factors default to the neutral value 1.f.
  struct ParamsBase {
      // User-provided on purpose: keeps the type a non-aggregate, as before.
      ParamsBase() {}

      // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
      int ctas_per_col = 0;

      // Input is interpreted as matrix. We normalize across columns.
      int rows = 0;
      int cols = 0;

      // Common data pointers.
      void *x0 = nullptr;
      void *x1 = nullptr;
      void *residual = nullptr;
      void *x = nullptr;
      void *dmask = nullptr;
      void *dmask1 = nullptr;
      void *mu = nullptr;
      void *rs = nullptr;
      void *gamma = nullptr;
      void *gamma1 = nullptr;
      void *rowscale = nullptr;
      void *colscale = nullptr;
      void *x0_subset = nullptr;
      void *z_subset = nullptr;

      // 0 until the caller sets it (presumably 1 / cols -- confirm against the caller).
      float inverse_cols = 0.f;
      float dropout_keep_p = 1.f;
      float dropout_scale = 1.f;
      // Defaults to 1.f, matching the other scale factors.
      float rowscale_const = 1.f;
      bool is_rms_norm = false;

      // Multi-CTA workspace in gmem.
      void *workspace = nullptr;

      // Multi-CTA sync barriers in gmem.
      int *barrier = nullptr;
  };
  71. ////////////////////////////////////////////////////////////////////////////////////////////////////
  72. struct FwdParams : public ParamsBase {
  73. FwdParams()
  74. : ParamsBase()
  75. , z(nullptr)
  76. , z1(nullptr)
  77. , beta(nullptr)
  78. , beta1(nullptr)
  79. , epsilon(0.f)
  80. {
  81. }
  82. // Output of LN FWD.
  83. void *z;
  84. void *z1;
  85. void *beta;
  86. void *beta1;
  87. float epsilon;
  88. // Random state.
  89. at::PhiloxCudaState philox_args;
  90. };
  91. ////////////////////////////////////////////////////////////////////////////////////////////////////
  92. struct BwdParams : public ParamsBase {
  93. BwdParams()
  94. : ParamsBase()
  95. , dz(nullptr)
  96. , dz1(nullptr)
  97. , dx(nullptr)
  98. , dbeta_part(nullptr)
  99. , dgamma_part(nullptr)
  100. , dbeta1_part(nullptr)
  101. , dgamma1_part(nullptr)
  102. , dcolscale_part(nullptr)
  103. , dx0(nullptr)
  104. , dx1(nullptr)
  105. , dresidual(nullptr)
  106. , dbeta(nullptr)
  107. , dgamma(nullptr)
  108. , dbeta1(nullptr)
  109. , dgamma1(nullptr)
  110. , dcolscale(nullptr)
  111. {
  112. }
  113. // Input: gradient wrt. LN FWD output.
  114. void *dz;
  115. void *dz1;
  116. // Input: gradient wrt residual.
  117. void *dx;
  118. // Workspace for Wgrad pre-reduction.
  119. void *dbeta_part;
  120. void *dgamma_part;
  121. void *dbeta1_part;
  122. void *dgamma1_part;
  123. void *dcolscale_part;
  124. // Output: Dgrad.
  125. void *dx0;
  126. void *dx1;
  127. void *dresidual;
  128. // Output: Wgrad.
  129. void *dbeta;
  130. void *dgamma;
  131. void *dbeta1;
  132. void *dgamma1;
  133. void *dcolscale;
  134. };
  135. ////////////////////////////////////////////////////////////////////////////////////////////////////
  136. using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
  137. using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
  138. using FunctionKey = uint64_t;
  139. using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
  140. using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
  141. extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
  142. extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
  143. ////////////////////////////////////////////////////////////////////////////////////////////////////
  144. using fp32 = float;
  145. using fp16 = half;
  146. using bf16 = nv_bfloat16;
  147. ////////////////////////////////////////////////////////////////////////////////////////////////////
  148. template<typename T>
  149. struct TypeId{};
  150. template<>
  151. struct TypeId<fp16>{
  152. constexpr static uint32_t Value = 0;
  153. };
  154. template<>
  155. struct TypeId<bf16>{
  156. constexpr static uint32_t Value = 1;
  157. };
  158. template<>
  159. struct TypeId<fp32>{
  160. constexpr static uint32_t Value = 2;
  161. };
  162. ////////////////////////////////////////////////////////////////////////////////////////////////////
  163. template<typename T, int S>
  164. struct Type2Key{
  165. constexpr static uint32_t Value = TypeId<T>::Value << S;
  166. };
  167. ////////////////////////////////////////////////////////////////////////////////////////////////////
  168. template<typename T>
  169. struct WeightType2Key : public Type2Key<T, 0>{};
  170. template<typename T>
  171. struct InputType2Key : public Type2Key<T, 2>{};
  172. template<typename T>
  173. struct ResidualType2Key : public Type2Key<T, 4>{};
  174. template<typename T>
  175. struct OutputType2Key : public Type2Key<T, 6>{};
  176. template<typename T>
  177. struct ComputeType2Key : public Type2Key<T, 8>{};
  178. ////////////////////////////////////////////////////////////////////////////////////////////////////
  179. template<typename W, typename I, typename R, typename O, typename C>
  180. struct Types2Key{
  181. constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
  182. constexpr static inline uint64_t get(const uint64_t hidden_size){
  183. constexpr uint64_t type_key = Value;
  184. return (type_key << 32) | hidden_size;
  185. }
  186. };
  187. ////////////////////////////////////////////////////////////////////////////////////////////////////
  188. template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
  189. struct FwdRegistrar{
  190. FwdRegistrar(FwdFunction f){
  191. uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
  192. FWD_FUNCS.insert({ key, f });
  193. }
  194. };
  195. ////////////////////////////////////////////////////////////////////////////////////////////////////
  196. template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
  197. struct BwdRegistrar{
  198. BwdRegistrar(BwdFunction f){
  199. uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
  200. BWD_FUNCS.insert({ key, f });
  201. }
  202. };
  203. ////////////////////////////////////////////////////////////////////////////////////////////////////
  204. template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
  205. struct FwdParallelRegistrar{
  206. FwdParallelRegistrar(FwdFunction f){
  207. uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
  208. PARALLEL_FWD_FUNCS.insert({ key, f });
  209. }
  210. };
  211. ////////////////////////////////////////////////////////////////////////////////////////////////////
  212. template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
  213. struct BwdParallelRegistrar{
  214. BwdParallelRegistrar(BwdFunction f){
  215. uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
  216. PARALLEL_BWD_FUNCS.insert({ key, f });
  217. }
  218. };
  219. ////////////////////////////////////////////////////////////////////////////////////////////////////
  220. } // namespace layer_norm