- #pragma once
- #include <cassert>
- #include <cstdio>
- #include <cstdlib>
- #include <cuda_bf16.h>
- #include <cuda_fp16.h>
- #include "ln.h"
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- constexpr uint32_t THREADS_PER_WARP = 32;
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- inline void check_cuda_(cudaError_t status, const char *file, int line) {
- if( status != cudaSuccess ) {
- fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
- exit(status);
- }
- }
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define CHECK_CUDA(ans) \
- { check_cuda_((ans), __FILE__, __LINE__); }
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define DIVUP(x, y) (((x) + ((y)-1)) / (y))
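- // Usage sketch (illustrative values, not from this codebase): DIVUP rounds an
- // integer division up, e.g. when sizing a grid, and CHECK_CUDA aborts with a
- // file/line message on any CUDA error:
- //
- // const int rows = 4096, rows_per_cta = 128;
- // const int grid = DIVUP(rows, rows_per_cta); // == 32
- // CHECK_CUDA(cudaDeviceSynchronize());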
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
- void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
- const bool configure_params) { \
- launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
- launch_params, configure_params); \
- } \
- static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
- ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
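- // Hypothetical instantiation (the real type tokens and tile parameters live in
- // the generated per-size .cu files; fp16/fp32 are assumed type aliases here):
- //
- // REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
- //
- // This defines and registers ln_fwd_768_fp16_fp16_fp16_fp16_fp32, launching one
- // CTA per row with 4 warps in M, 1 warp in N, and 16-byte vectorized loads.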
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define REGISTER_BWD_LAUNCHER( \
- HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
- void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
- const bool configure_params) { \
- launch_<WTYPE, \
- ITYPE, \
- RTYPE, \
- OTYPE, \
- CTYPE, \
- uint32_t, \
- HIDDEN_SIZE, \
- CTAS_PER_ROW, \
- WARPS_M, \
- WARPS_N, \
- BYTES_PER_LDG, \
- BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
- } \
- static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
- ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
- void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
- const bool configure_params) { \
- launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
- launch_params, configure_params); \
- } \
- static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
- ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- #define REGISTER_PARALLEL_BWD_LAUNCHER( \
- HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
- void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
- const bool configure_params) { \
- launch_parallel_residual_<WTYPE, \
- ITYPE, \
- RTYPE, \
- OTYPE, \
- CTYPE, \
- uint32_t, \
- HIDDEN_SIZE, \
- CTAS_PER_ROW, \
- WARPS_M, \
- WARPS_N, \
- BYTES_PER_LDG, \
- BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
- } \
- static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
- ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- inline __device__ float2 operator+(const float2 & a, const float2 & b){
- return {a.x + b.x, a.y + b.y};
- }
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- inline __device__ void operator+=(float2 & a, const float2 & b){
- a.x += b.x;
- a.y += b.y;
- }
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T>
- struct Sum {
- inline __device__ Sum(){}
- inline __device__ T operator()(const T &a, const T &b){
- return a + b;
- }
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T>
- inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
- return __shfl_xor_sync(uint32_t(-1), x, idx);
- }
- template<>
- inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
- return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
- }
- template<typename T>
- inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
- return __shfl_down_sync(uint32_t(-1), x, idx);
- }
- template<>
- inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
- return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
- }
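- // Illustrative pattern (a sketch, assuming all 32 lanes are active): a
- // warp-wide butterfly allreduce built from warp_shuffle_xor, the same shape
- // used by Reducer::allreduce_ below:
- //
- // float v = thread_local_value;
- // #pragma unroll
- // for( int step = 1; step < THREADS_PER_WARP; step *= 2 ) {
- //     v += warp_shuffle_xor(v, step);
- // }
- // // every lane now holds the sum over the whole warp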
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- namespace layer_norm {
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- struct uint16 {
- uint4 u;
- uint4 v;
- uint4 s;
- uint4 t;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- struct uint8 {
- uint4 u;
- uint4 v;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<int BYTES>
- struct BytesToType {};
- template<>
- struct BytesToType<64> {
- using Type = uint16;
- static_assert(sizeof(Type) == 64);
- };
- template<>
- struct BytesToType<32> {
- using Type = uint8;
- static_assert(sizeof(Type) == 32);
- };
- template<>
- struct BytesToType<16> {
- using Type = uint4;
- static_assert(sizeof(Type) == 16);
- };
- template<>
- struct BytesToType<8> {
- using Type = uint64_t;
- static_assert(sizeof(Type) == 8);
- };
- template<>
- struct BytesToType<4> {
- using Type = uint32_t;
- static_assert(sizeof(Type) == 4);
- };
- template<>
- struct BytesToType<2> {
- using Type = uint16_t;
- static_assert(sizeof(Type) == 2);
- };
- template<>
- struct BytesToType<1> {
- using Type = uint8_t;
- static_assert(sizeof(Type) == 1);
- };
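- // Example: eight halves occupy 16 bytes, so they can move in one uint4
- // transaction:
- //
- // using LoadType = typename BytesToType<8 * sizeof(half)>::Type; // uint4
- // static_assert(sizeof(LoadType) == 16);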
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T>
- struct TypeToVec2 {};
- template<>
- struct TypeToVec2<float> {
- using Type = float2;
- };
- template<>
- struct TypeToVec2<half> {
- using Type = half2;
- };
- template<>
- struct TypeToVec2<nv_bfloat16> {
- using Type = nv_bfloat162;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<int INDEX>
- struct Get {
- template<typename T, typename R>
- static inline __device__ R of(const T &vec);
- };
- template<>
- template<typename T, typename R>
- inline __device__ R Get<0>::of(const T &vec) {
- return vec.x;
- }
- template<>
- template<typename T, typename R>
- inline __device__ R Get<1>::of(const T &vec) {
- return vec.y;
- }
- template<>
- template<typename T, typename R>
- inline __device__ R Get<2>::of(const T &vec) {
- return vec.z;
- }
- template<>
- template<typename T, typename R>
- inline __device__ R Get<3>::of(const T &vec) {
- return vec.w;
- }
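- // Example: for a float2 v, Get<0>::of<float2, float>(v) is v.x and
- // Get<1>::of<float2, float>(v) is v.y; Stats uses this below to unpack the
- // (mean, m2) pair from its stats_t vector type.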
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename Src, typename Dst>
- struct Converter{
- static inline __device__ Dst convert(const Src &from) {
- return Dst(from);
- }
- };
- template<>
- struct Converter<float2, half2>{
- static inline __device__ half2 convert(const float2 &x) {
- return __float22half2_rn(x);
- }
- };
- template<>
- struct Converter<float2, nv_bfloat162>{
- static inline __device__ nv_bfloat162 convert(const float2 &x) {
- #if __CUDA_ARCH__ >= 800
- return __float22bfloat162_rn(x);
- #else
- union {
- nv_bfloat162 raw;
- nv_bfloat16 elt[2]; // distinct array members; two scalar members here would alias at offset 0
- } tmp;
- tmp.elt[0] = __float2bfloat16_rn(x.x);
- tmp.elt[1] = __float2bfloat16_rn(x.y);
- return tmp.raw;
- #endif
- }
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T>
- struct Zeros{
- static inline __device__ T get() {
- return T(0.f);
- }
- };
- template<>
- struct Zeros<float2>{
- static inline __device__ float2 get() {
- return make_float2(0.f, 0.f);
- }
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename Elt_type, uint32_t NUM_ELT>
- struct Vec {
- enum { BYTES = NUM_ELT * sizeof(Elt_type) };
- using Vec_type = typename BytesToType<BYTES>::Type;
- using Alias_type = union {
- Vec_type vec;
- Elt_type elt[NUM_ELT];
- };
- Alias_type data;
- template<typename S>
- inline __device__ void to(Vec<S, NUM_ELT> &other) {
- #pragma unroll
- for( int it = 0; it < NUM_ELT; it++ ) {
- other.data.elt[it] = S(this->data.elt[it]);
- }
- }
- template<typename Op>
- inline __device__ void assign(const Op &op) {
- #pragma unroll
- for( int it = 0; it < NUM_ELT; it++ ) {
- this->data.elt[it] = op(it);
- }
- }
- inline __device__ void zero_() {
- #pragma unroll
- for( int it = 0; it < NUM_ELT; it++ ) {
- this->data.elt[it] = Elt_type(0.f);
- }
- }
- inline __device__ void load_from(const void *base_ptr, const size_t idx) {
- this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
- }
- inline __device__ void store_to(void *base_ptr, const size_t idx) {
- static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
- }
- };
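- // Usage sketch (hypothetical pointers/indices): one Vec<half, 8> packs eight
- // halves into a single 16-byte uint4 transaction:
- //
- // Vec<half, 8> v; // BYTES == 16, Vec_type == uint4
- // v.load_from(row_in, col_idx); // one vectorized load
- // Vec<float, 8> f;
- // v.to(f); // elementwise half -> float widening
- // f.store_to(row_out, col_idx); // 32-byte store via uint8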
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<uint32_t CTAS_PER_ROW>
- struct InterCTASync {
- template<typename Params>
- inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
- : phase_counter_(0)
- , b0_(params.barrier + bidm) // The first barrier for this group of CTAs.
- , b1_(params.barrier + bidm + params.ctas_per_col) // The second barrier, used on the alternate phase.
- {
- // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
- }
- inline __device__ void spin_wait_(int *barrier, int step, int expected) {
- asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
- for( int found = -1; found != expected; ) {
- asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
- }
- }
- inline __device__ void sync(){
- // ALL THREADS MUST ENTER!
- // We switch barrier every iteration.
- int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
- // We decrement every other iteration.
- bool dec = phase_counter_ & 0x2;
- int step = dec ? -1 : 1;
- int expected = dec ? 0 : CTAS_PER_ROW;
- // There are only 4 phases: up/down for b0/b1.
- phase_counter_ = (phase_counter_ + 1) & 0x3;
- if( threadIdx.x == 0 ) {
- spin_wait_(barrier, step, expected);
- }
- // CTA waits for thread 0
- __syncthreads();
- }
- int phase_counter_;
- int * b0_;
- int * b1_;
- };
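- // Protocol sketch: successive sync() calls cycle through four phases, e.g. for
- // CTAS_PER_ROW == 4: (b0, count up to 4), (b1, up to 4), (b0, count down to 0),
- // (b1, down to 0). Alternating two barriers and counting up then down lets the
- // same zero-initialized counters be reused with no reset step in between.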
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
- struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
- using InterCTASync = InterCTASync<CTAS_PER_ROW>;
- using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
- using Type = typename Base::Type;
- enum { SMEM_BYTES = Base::SMEM_BYTES };
- enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
- enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
- // Size of the barriers + one temporary result per CTA (multiply by CTAS_PER_ROW for the total).
- enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
- template<typename Params>
- inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
- , inter_cta_(params, bidm, bidn)
- , bidn_(bidn) // CTA id within the group.
- , w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
- , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
- {
- }
- template<typename Op>
- inline __device__ T allreduce(T data, Op &op) {
- data = Base::reduce(data, op);
- // We switch workspace every iteration.
- T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
- // Lane 0 of the first warp in each row (warp_n_ == 0) holds the CTA-local result.
- if( this->warp_n_ == 0 && this->lane_ == 0 ) {
- workspace[bidn_] = data;
- }
- inter_cta_.sync();
- static_assert(CTAS_PER_ROW <= 32);
- T total = Zeros<T>::get();
- if(this->lane_ < CTAS_PER_ROW){
- total = workspace[this->lane_];
- }
- total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
- return total;
- }
- InterCTASync inter_cta_;
- T *w0_;
- T *w1_;
- int bidn_;
- };
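- // Note: the workspace halves w0_/w1_ are double-buffered in lockstep with the
- // barrier phase, so a CTA that races ahead to the next allreduce cannot
- // overwrite partial results another CTA is still reading.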
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t WARPS_M>
- struct Reducer<T, 1, WARPS_M, 1> {
- using Type = T;
- enum { SMEM_BYTES = 0 };
- enum { WORKSPACE_BYTES_PER_GROUP = 0 };
- enum { THREADS_PER_WARP = 32 };
- template<typename Params>
- inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : warp_n_(warp_n)
- , lane_(lane)
- {
- }
- template<typename Op>
- static inline __device__ T allreduce_(T data, Op &op) {
- #pragma unroll
- for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
- data = op(data, warp_shuffle_xor(data, it));
- }
- return data;
- }
- template<typename Op>
- inline __device__ T allreduce(T data, Op &op) {
- return allreduce_(data, op);
- }
- template<typename Op>
- inline __device__ T reduce(T data, Op &op){
- // only lane 0 holds the result!
- #pragma unroll
- for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
- data = op(data, warp_shuffle_down(data, it));
- }
- return data;
- }
- int warp_n_;
- int lane_;
- };
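- // Example (single-warp rows): after
- //
- // Sum<float> sum;
- // float total = reducer.allreduce(local, sum);
- //
- // every lane holds the row total, whereas reducer.reduce(local, sum) leaves the
- // total in lane 0 only, since the shuffle-down tree never broadcasts back.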
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
- struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
- using Base = Reducer<T, 1, WARPS_M, 1>;
- using Type = T;
- enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
- enum { WORKSPACE_BYTES_PER_GROUP = 0 };
- enum { THREADS_PER_WARP = 32 };
- template<typename Params>
- inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
- , use0_(true)
- {
- smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
- smem1_ = smem0_ + WARPS_M * WARPS_N;
- }
- template<typename Op>
- inline __device__ T allreduce(T data, Op & op) {
- T * smem = use0_ ? smem0_ : smem1_;
- use0_ = !use0_;
- data = Base::reduce(data, op);
- if( this->lane_ == 0 ) {
- smem[this->warp_n_] = data;
- }
- __syncthreads();
- T out = Zeros<T>::get();
- #pragma unroll
- for( int it = 0; it < WARPS_N; it++ ) {
- out = op(out, smem[it]);
- }
- return out;
- }
- template<typename Op>
- inline __device__ T reduce(T data, Op &op) {
- T * smem = use0_ ? smem0_ : smem1_;
- use0_ = !use0_;
- // only intra-CTA group leader holds the result!
- data = Base::reduce(data, op);
- if( this->lane_ == 0 ) {
- smem[this->warp_n_] = data;
- }
- __syncthreads();
- T out = Zeros<T>::get();
- if( this->warp_n_ == 0 && this->lane_ == 0 ) {
- #pragma unroll
- for( int it = 0; it < WARPS_N; it++ ) {
- out = op(out, smem[it]);
- }
- }
- return out;
- }
- T * smem0_;
- T * smem1_;
- bool use0_;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename T, typename int_t>
- inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
- // Assume at least the leftmost value is valid, and initialize
- // step = next_pow2(num_active) / 2 (we might get NaNs otherwise).
- const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
-
- #pragma unroll
- for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
- // Exchange
- int_t n_b = warp_shuffle_down(n_a, step);
- T m_b = warp_shuffle_down(m_a, step);
- T m2_b = warp_shuffle_down(m2_a, step);
- // Update
- const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
- const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
- const T delta = m_a - m_b;
- const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
- const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
- n_a = n_ab;
- m_a = m_ab;
- m2_a = m2_ab;
- }
- // Intra-warp broadcast (only lane 0 has valid stats).
- m_a = __shfl_sync(uint32_t(-1), m_a, 0);
- m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
- }
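- // This is Chan et al.'s pairwise update for parallel mean/variance: merging
- // partitions A and B with counts n_a, n_b, means m_a, m_b and sums of squared
- // deviations m2_a, m2_b gives
- //
- // delta = m_a - m_b
- // m_ab  = (n_a * m_a + n_b * m_b) / (n_a + n_b)
- // m2_ab = m2_a + m2_b + delta^2 * n_a * n_b / (n_a + n_b)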
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
- struct Stats {
- // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
- using InterCTASync = InterCTASync<CTAS_PER_ROW>;
- using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
- using stats_t = typename BlockStats::stats_t;
- enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
- template<typename Params>
- inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : inter_cta_(params, bidm, bidn)
- , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
- , bidn_(bidn) // CTA id within the group.
- , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
- , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
- , warp_n_(warp_n)
- , lane_(lane)
- {
- }
- template<uint32_t N>
- inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
- constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
- // TODO: rn is not actually needed here.
- constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
- stats_t block_stats = block_stats_.compute(elts, block_rn);
- stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
- if( warp_n_ == 0 && lane_ == 0 ) {
- workspace[bidn_] = block_stats;
- }
- // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
- inter_cta_.sync();
- T n = Zeros<T>::get();
- T m = Zeros<T>::get();
- T m2 = Zeros<T>::get();
- // Assume the CTA group size in N is at most 32, so that we can finalize with a single warp.
- static_assert(CTAS_PER_ROW <= 32);
- // Every warp does the final reduction locally.
- if( lane_ < CTAS_PER_ROW ) {
- stats_t result = workspace[lane_];
- n = ELTS_PER_ROW_PER_CTA;
- m = layer_norm::Get<0>::of<stats_t, T>(result);
- m2 = layer_norm::Get<1>::of<stats_t, T>(result);
- }
- warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
- return { m, m2 };
- }
- InterCTASync inter_cta_;
- BlockStats block_stats_;
- stats_t *w0_;
- stats_t *w1_;
- int bidn_;
- int warp_n_;
- int lane_;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
- struct Stats<T, 1, WARPS_M, WARPS_N> {
- using WarpStats = Stats<T, 1, WARPS_M, 1>;
- using stats_t = typename WarpStats::stats_t;
- enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
- template<typename Params>
- inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
- , use0_(true)
- {
- smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
- smem1_ = smem0_ + WARPS_M * WARPS_N;
- }
- template<bool Is_even_cols, uint32_t N, typename function_t>
- inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
- function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
- stats_t * smem = use0_ ? smem0_ : smem1_;
- use0_ = !use0_;
- // Compute warp-local stats for each of the WARPS_N warps.
- const auto warp_n = warp_stats_.reducer_.warp_n_;
- const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
- stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
- elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts
- );
- // Each warp's leader (lane 0) stores its warp-local stats.
- const auto lane = warp_stats_.reducer_.lane_;
- if( lane == 0 ) {
- smem[warp_n] = warp_stats;
- }
- __syncthreads();
- int n = 0;
- T m = Zeros<T>::get();
- T m2 = Zeros<T>::get();
- // Assume there are at most 32 warps in N, so that we can finalize with a single warp.
- static_assert(WARPS_N <= 32);
- if(lane < WARPS_N){
- stats_t result = smem[lane];
- n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
- m = layer_norm::Get<0>::of<stats_t, T>(result);
- m2 = layer_norm::Get<1>::of<stats_t, T>(result);
- }
- warp_chan_upd_dynamic(m, m2, n, WARPS_N);
- return { m, m2 };
- }
- WarpStats warp_stats_;
- stats_t * smem0_;
- stats_t * smem1_;
- bool use0_;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, uint32_t WARPS_M>
- struct Stats<T, 1, WARPS_M, 1> {
- using stats_t = typename TypeToVec2<T>::Type;
- // The simple Warp reducer.
- using Reducer = Reducer<T, 1, WARPS_M, 1>;
- enum { SMEM_BYTES = 0 };
- template<typename Params>
- inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
- : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
- {
- }
- template<bool Is_even_cols, uint32_t N, typename function_t>
- inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
- function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
- auto sum = Sum<T>();
- T m = Zeros<T>::get();
- #pragma unroll
- for( int it = 0; it < N; it++ ) {
- if (Is_even_cols || (it < num_valid_elts)) {
- m += elts[it];
- }
- }
- m = reducer_.allreduce(m, sum) * row_norm_factor;
- T m2 = Zeros<T>::get();
- #pragma unroll
- for( int it = 0; it < N; it++ ) {
- if (Is_even_cols || (it < num_valid_elts)) {
- T diff = (elts[it] - m);
- m2 += diff * diff;
- }
- }
- m2 = reducer_.allreduce(m2, sum);
- return {m, m2};
- }
- Reducer reducer_;
- };
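- // Note: the returned pair is (mean, sum of squared deviations); m2 is left
- // unnormalized so that the callers above can merge partial warp results with
- // their element counts via warp_chan_upd_dynamic.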
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- } // namespace layer_norm