// ln_utils.cuh
#pragma once

#include <cassert>
#include <cstdio>   // fprintf (used by check_cuda_)
#include <cstdlib>  // exit (used by check_cuda_)

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "ln.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

inline void check_cuda_(cudaError_t status, const char *file, int line) {
    if( status != cudaSuccess ) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
        exit(status);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(ans) \
    { check_cuda_((ans), __FILE__, __LINE__); }

////////////////////////////////////////////////////////////////////////////////////////////////////

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
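// Usage sketch (illustrative; the names below are placeholders, not part of this header):
// CHECK_CUDA wraps any call returning cudaError_t, and DIVUP is a ceiling division, e.g.
//
//   CHECK_CUDA(cudaMemsetAsync(barrier_ptr, 0, barrier_bytes, stream));  // barriers must start at 0
//   const int ctas_per_col = DIVUP(rows, rows_per_cta);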
////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
    void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
                                                                                const bool configure_params) { \
        launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
            launch_params, configure_params); \
    } \
    static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
        ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
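// Illustrative instantiation (the tuning values and the type tokens such as fp16/fp32 are
// placeholders here and must match what ln.h defines):
//
//   REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
//
// This expands to a launcher function ln_fwd_768_fp16_fp16_fp16_fp16_fp32 plus a static
// FwdRegistrar object whose constructor records that function, so a dispatcher can look it
// up at runtime by (weight, input, residual, output, compute) types and hidden size.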
////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_BWD_LAUNCHER( \
    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
    void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
                                                                                const bool configure_params) { \
        launch_<WTYPE, \
                ITYPE, \
                RTYPE, \
                OTYPE, \
                CTYPE, \
                uint32_t, \
                HIDDEN_SIZE, \
                CTAS_PER_ROW, \
                WARPS_M, \
                WARPS_N, \
                BYTES_PER_LDG, \
                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
    } \
    static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
        ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
    void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
                                                                                                  const bool configure_params) { \
        launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
            launch_params, configure_params); \
    } \
    static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
        ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_PARALLEL_BWD_LAUNCHER( \
    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
    void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
                                                                                                  const bool configure_params) { \
        launch_parallel_residual_<WTYPE, \
                                  ITYPE, \
                                  RTYPE, \
                                  OTYPE, \
                                  CTYPE, \
                                  uint32_t, \
                                  HIDDEN_SIZE, \
                                  CTAS_PER_ROW, \
                                  WARPS_M, \
                                  WARPS_N, \
                                  BYTES_PER_LDG, \
                                  BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
    } \
    static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
        ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void operator+=(float2 & a, const float2 & b){
    a.x += b.x;
    a.y += b.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct Sum {
    inline __device__ Sum(){}
    inline __device__ T operator()(const T &a, const T &b){
        return a + b;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
    return __shfl_xor_sync(uint32_t(-1), x, idx);
}

template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
    return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}

template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
    return __shfl_down_sync(uint32_t(-1), x, idx);
}

template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
    return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}

////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint16 {
    uint4 u;
    uint4 v;
    uint4 s;
    uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint8 {
    uint4 u;
    uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES>
struct BytesToType {};

template<>
struct BytesToType<64> {
    using Type = uint16;
    static_assert(sizeof(Type) == 64);
};

template<>
struct BytesToType<32> {
    using Type = uint8;
    static_assert(sizeof(Type) == 32);
};

template<>
struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<>
struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<>
struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<>
struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<>
struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct TypeToVec2 {};

template<>
struct TypeToVec2<float> {
    using Type = float2;
};

template<>
struct TypeToVec2<half> {
    using Type = half2;
};

template<>
struct TypeToVec2<nv_bfloat16> {
    using Type = nv_bfloat162;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int INDEX>
struct Get {
    template<typename T, typename R>
    static inline __device__ R of(const T &vec);
};

template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
    return vec.x;
}

template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
    return vec.y;
}

template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
    return vec.z;
}

template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
    return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Src, typename Dst>
struct Converter{
    static inline __device__ Dst convert(const Src &from) {
        return Dst(from);
    }
};

template<>
struct Converter<float2, half2>{
    static inline __device__ half2 convert(const float2 &x) {
        return __float22half2_rn(x);
    }
};

template<>
struct Converter<float2, nv_bfloat162>{
    static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
        return __float22bfloat162_rn(x);
#else
        // Pre-Ampere fallback: pack the two converted halves through a union so that
        // elt[0]/elt[1] occupy the low/high elements of the nv_bfloat162.
        union {
            nv_bfloat162 raw;
            nv_bfloat16 elt[2];
        } tmp;
        tmp.elt[0] = __float2bfloat16_rn(x.x);
        tmp.elt[1] = __float2bfloat16_rn(x.y);
        return tmp.raw;
#endif
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
    static inline __device__ T get() {
        return T(0.f);
    }
};

template<>
struct Zeros<float2>{
    static inline __device__ float2 get() {
        return make_float2(0.f, 0.f);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {

    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    using Vec_type = typename BytesToType<BYTES>::Type;

    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    template<typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
        #pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            other.data.elt[it] = S(this->data.elt[it]);
        }
    }

    template<typename Op>
    inline __device__ void assign(const Op &op) {
        #pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            this->data.elt[it] = op(it);
        }
    }

    inline __device__ void zero_() {
        #pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            this->data.elt[it] = Elt_type(0.f);
        }
    }

    inline __device__ void load_from(const void *base_ptr, const size_t idx) {
        this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
    }

    inline __device__ void store_to(void *base_ptr, const size_t idx) {
        static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
    }
};
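// Usage sketch (illustrative; row_ptr/col are placeholder names): with Elt_type = half and
// NUM_ELT = 8, BYTES is 16 and Vec_type resolves to uint4, so load_from()/store_to() perform
// a single 16-byte vectorized access, while to() converts element by element:
//
//   Vec<half, 8> x;
//   x.load_from(row_ptr, col);   // one 16B load, matching BYTES_PER_LDG = 16
//   Vec<float, 8> xf;
//   x.to(xf);                    // half -> float, one element at a time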
////////////////////////////////////////////////////////////////////////////////////////////////////

template<uint32_t CTAS_PER_ROW>
struct InterCTASync {

    template<typename Params>
    inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
        : phase_counter_(0)
        , b0_(params.barrier + bidm)                        // The barrier for this group of CTAs.
        , b1_(params.barrier + bidm + params.ctas_per_col)  // The barrier for this group of CTAs.
    {
        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
    }

    inline __device__ void spin_wait_(int *barrier, int step, int expected) {
        asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
        for( int found = -1; found != expected; ) {
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
        }
    }

    inline __device__ void sync(){
        // ALL THREADS MUST ENTER!

        // We switch barrier every iteration.
        int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;

        // We decrement every other iteration.
        bool dec = phase_counter_ & 0x2;
        int step = dec ? -1 : 1;
        int expected = dec ? 0 : CTAS_PER_ROW;

        // There are only 4 phases: up/down for b0/b1.
        phase_counter_ = (phase_counter_ + 1) & 0x3;

        if( threadIdx.x == 0 ) {
            spin_wait_(barrier, step, expected);
        }
        // CTA waits for thread 0
        __syncthreads();
    }

    int phase_counter_;
    int * b0_;
    int * b1_;
};
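// Summary of the barrier protocol above (restating the code): sync() alternates between the
// two global barriers b0_/b1_ and between counting up and counting down, i.e. four phases.
// Thread 0 of each CTA adds +1 or -1 with release semantics, then spins on acquire loads
// until the counter reaches CTAS_PER_ROW (up phase) or 0 (down phase); __syncthreads() then
// releases the rest of the CTA. Because each barrier is reused only every other phase, it
// never needs to be re-initialized, provided it starts at 0.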
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
    using Type = typename Base::Type;

    enum { SMEM_BYTES = Base::SMEM_BYTES };

    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };

    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , inter_cta_(params, bidm, bidn)
        , bidn_(bidn)  // CTA id within the group.
        , w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
    {
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        data = Base::reduce(data, op);
        // We switch workspace every iteration.
        T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        // The lane-0 thread of each warp_n == 0 warp holds the CTA-local result.
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            workspace[bidn_] = data;
        }

        inter_cta_.sync();

        static_assert(CTAS_PER_ROW <= 32);
        T total = Zeros<T>::get();
        if(this->lane_ < CTAS_PER_ROW){
            total = workspace[this->lane_];
        }
        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

        return total;
    }

    InterCTASync inter_cta_;

    T *w0_;
    T *w1_;
    int bidn_;
};
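// Note on the workspace layout above (restating the code): each CTA of a group writes its
// per-(bidm, warp_m) partial into slot bidn_ of a global-memory workspace, and the two
// buffers w0_/w1_ are alternated in step with InterCTASync's phase counter so that a new
// reduction can begin before every CTA has finished reading the previous round's results.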
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {

    using Type = T;

    enum { SMEM_BYTES = 0 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : warp_n_(warp_n)
        , lane_(lane)
    {
    }

    template<typename Op>
    static inline __device__ T allreduce_(T data, Op &op) {
        #pragma unroll
        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
            data = op(data, warp_shuffle_xor(data, it));
        }
        return data;
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        return allreduce_(data, op);
    }

    template<typename Op>
    inline __device__ T reduce(T data, Op &op){
        // only lane 0 holds the result!
        #pragma unroll
        for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
            data = op(data, warp_shuffle_down(data, it));
        }
        return data;
    }

    int warp_n_;
    int lane_;
};
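// Illustration (assumes a full warp of 32 active lanes): allreduce_ is an XOR-shuffle
// butterfly, so after the 5 steps (strides 1, 2, 4, 8, 16) every lane holds the same reduced
// value; reduce() uses shfl_down with strides 16, 8, 4, 2, 1 and leaves the result in lane 0
// only. For example, for a per-lane float `val`:
//
//   Sum<float> sum;
//   float total = Reducer<float, 1, 1, 1>::allreduce_(val, sum);  // identical on all 32 lanes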
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {

    using Base = Reducer<T, 1, WARPS_M, 1>;

    using Type = T;

    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op & op) {
        T * smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < WARPS_N; it++ ) {
            out = op(out, smem[it]);
        }
        return out;
    }

    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        T * smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        // only intra-CTA group leader holds the result!
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            #pragma unroll
            for( int it = 0; it < WARPS_N; it++ ) {
                out = op(out, smem[it]);
            }
        }
        return out;
    }

    T * smem0_;
    T * smem1_;
    bool use0_;
};
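// Note on the double-buffered shared memory above (restating the code): successive calls to
// allreduce()/reduce() alternate between smem0_ and smem1_, so a new round can store its
// per-warp partials without an extra __syncthreads() to protect threads that are still
// reading the previous round's buffer.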
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, typename int_t>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
    // Assume at least the leftmost lane is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise).
    const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

    #pragma unroll
    for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
        // Exchange
        int_t n_b = warp_shuffle_down(n_a, step);
        T m_b = warp_shuffle_down(m_a, step);
        T m2_b = warp_shuffle_down(m2_a, step);

        // Update
        const int_t n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
        const T rn_ab = 1.f / n_ab;    // Might have different n per thread, otherwise this would simplify :(
        const T delta = m_a - m_b;
        const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
        const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

        n_a = n_ab;
        m_a = m_ab;
        m2_a = m2_ab;
    }
    // Intra-warp broadcast (only lane 0 has valid stats).
    m_a = __shfl_sync(uint32_t(-1), m_a, 0);
    m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
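// The update above is the standard pairwise merge of running statistics (Chan et al. style),
// which the loop applies across lanes: for two partitions a and b with counts n_a, n_b,
// means m_a, m_b and sums of squared deviations m2_a, m2_b,
//
//   n_ab  = n_a + n_b
//   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
//   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab
//
// so partial (count, mean, m2) triples can be combined without revisiting the elements.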
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
    using stats_t = typename BlockStats::stats_t;

    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : inter_cta_(params, bidm, bidn)
        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , bidn_(bidn)  // CTA id within the group.
        , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
        , warp_n_(warp_n)
        , lane_(lane)
    {
    }

    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
        // TODO rn is not really needed here..
        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
        stats_t block_stats = block_stats_.compute(elts, block_rn);

        stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        if( warp_n_ == 0 && lane_ == 0 ) {
            workspace[bidn_] = block_stats;
        }

        // Wait for all CTAS_PER_ROW CTAs in the group to have written their result.
        inter_cta_.sync();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume the CTA group size in N is at most 32, such that we can finalize with a single warp.
        static_assert(CTAS_PER_ROW <= 32);

        // Every warp does the final reduction locally.
        if( lane_ < CTAS_PER_ROW ) {
            stats_t result = workspace[lane_];
            n = ELTS_PER_ROW_PER_CTA;
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

        return { m, m2 };
    }

    InterCTASync inter_cta_;
    BlockStats block_stats_;

    stats_t *w0_;
    stats_t *w1_;
    int bidn_;
    int warp_n_;
    int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {

    using WarpStats = Stats<T, 1, WARPS_M, 1>;
    using stats_t = typename WarpStats::stats_t;

    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    template<bool Is_even_cols, uint32_t N, typename function_t>
    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
        stats_t * smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;

        // Compute warp-local stats for all WARPS_N warps.
        const auto warp_n = warp_stats_.reducer_.warp_n_;
        const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
        stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
            elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts
        );

        // Each warp leader stores its stats.
        const auto lane = warp_stats_.reducer_.lane_;
        if( lane == 0 ) {
            smem[warp_n] = warp_stats;
        }

        __syncthreads();

        int n = 0;
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume there are at most 32 warps along N, such that we can finalize with a single warp.
        static_assert(WARPS_N <= 32);
        if(lane < WARPS_N){
            stats_t result = smem[lane];
            n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        warp_chan_upd_dynamic(m, m2, n, WARPS_N);

        return { m, m2 };
    }

    WarpStats warp_stats_;
    stats_t * smem0_;
    stats_t * smem1_;
    bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {

    using stats_t = typename TypeToVec2<T>::Type;
    // The simple Warp reducer.
    using Reducer = Reducer<T, 1, WARPS_M, 1>;

    enum { SMEM_BYTES = 0 };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
    {
    }

    template<bool Is_even_cols, uint32_t N, typename function_t>
    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
                                      // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) {
                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
        auto sum = Sum<T>();

        T m = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            if (Is_even_cols || (it < num_valid_elts)) {
                m += elts[it];
            }
        }
        m = reducer_.allreduce(m, sum) * row_norm_factor;

        T m2 = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            if (Is_even_cols || (it < num_valid_elts)) {
                T diff = (elts[it] - m);
                m2 += diff * diff;
            }
        }
        m2 = reducer_.allreduce(m2, sum);

        return {m, m2};
    }

    Reducer reducer_;
};
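// Note on the return value (follows from the code above): compute() returns { m, m2 }, where
// m is the mean of the elements handled (the reduced sum scaled by row_norm_factor) and m2 is
// the reduced sum of squared deviations from that mean; callers can recover the variance by
// multiplying m2 with the appropriate 1/count factor.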
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm