@@ -7,7 +7,7 @@

 namespace layer_norm {

-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_bwd_kernel(layer_norm::BwdParams params) {

@@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {

     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
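+    // Per-thread partial sums of the column-scale gradient; only used when Has_colscale.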
+    Cvec dcolscale_sum[LDGS];

     memset(dzy_sum, 0, sizeof(dzy_sum));
     memset(dz_sum, 0, sizeof(dz_sum));
+    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }

     compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
     char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
@@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;

     Wvec gamma[LDGS];
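+    // The column-scale vector is loaded once per thread, alongside gamma.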
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += Ktraits::VEC_COLS_PER_LDG;
         }
     }
@@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
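+                // The colscale gradient needs the original input x0, so it is re-read here.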
+                Ivec x0;
+                if (Has_colscale) { x0.load_from(params.x0, idx); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
@@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {
-                        dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                        dx0_tmp_res *= params.dropout_scale;
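+                        // Chain rule through x0 * colscale: d(colscale) accumulates dx0 * x0,
+                        // while dx0 itself picks up the colscale factor (both gated by the dropout mask).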
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
+                        } else {
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
+                        }
                     } else {
-                        dx0.data.elt[jt] = dx0_tmp_res;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
+                            dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
+                        } else {
+                            dx0.data.elt[jt] = dx0_tmp_res;
+                        }
                     }
                 }
                 if (Has_residual) { dx1.store_to(params.dx1, idx); }
@@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 dz_sum[it].store_to(params.dbeta_part, idx);
                 dzy_sum[it].store_to(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
                 idx += Ktraits::VEC_COLS_PER_LDG;
             }
         }
@@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             }
         }

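+        // Reduce the per-thread dcolscale partials across the CTA through smem_wgrad, mirroring the dzy/dz passes above.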
+        compute_t cta_dcolscale_sum[NUM_RES];
+        if (Has_colscale) {
+            __syncthreads();
+            idx = warp_m * Ktraits::VEC_COLS + tid_r;
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                dcolscale_sum[it].store_to(smem_wgrad, idx);
+                idx += THREADS_PER_ROW;
+            }
+            __syncthreads();
+            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
+            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                for( int jt = 0; jt < NUM_RES; jt++ ) {
+                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                }
+            }
+        }
+
         const index_t num_valid_writes
             = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
         compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
         compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
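+        // dcolscale partials go to the workspace next to dgamma_part/dbeta_part, to be combined by the finalize kernel.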
+        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
         for( int jt = 0; jt < NUM_RES; jt++ ) {
             if (Is_even_cols || (jt < num_valid_writes)) {
                 *dgamma_part = cta_dzy_sum[jt];
                 dgamma_part += Ktraits::THREADS_PER_CTA;
                 *dbeta_part = cta_dz_sum[jt];
                 dbeta_part += Ktraits::THREADS_PER_CTA;
+                if (Has_colscale) {
+                    *dcolscale_part = cta_dcolscale_sum[jt];
+                    dcolscale_part += Ktraits::THREADS_PER_CTA;
+                }
             }
         }

     }
 }

-template<typename Kernel_traits, bool Is_even_cols>
+template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
 void ln_bwd_finalize_kernel(BwdParams params)
 {
@@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
     constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
     for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
         // Each thread sums over NUM_ELT columns.
-        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
+        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
         memset(&dgamma_local, 0, sizeof(dgamma_local));
         memset(&dbeta_local, 0, sizeof(dbeta_local));
+        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
         if (Is_even_cols || col < params.cols) {
             for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
-                // index_t idx = row * Kernel_traits::COLS + col;
                 index_t idx = row * params.cols + col;

-                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
                 dbeta_part.load_from(params.dbeta_part, idx);
                 dgamma_part.load_from(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
                 #pragma unroll
                 for( int it = 0; it < NUM_ELT; it++ ) {
                     dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                     dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
                 }
             }
         }
         void * smem_gamma = smem_;
         void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
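+        // Third transpose buffer in shared memory, used for the dcolscale reduction.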
+        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];

         const int write_row = warp;
         const int write_col = lane ^ write_row;
@@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)

         dgamma_local.store_to(smem_gamma, write_idx);
         dbeta_local.store_to(smem_beta, write_idx);
+        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }

         __syncthreads();

         // It would be probably safe to reuse the first row of smem_beta and smem_gamma
-        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
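+        // NUM_FACTORS presumably counts the reduced tensors (2 without colscale, 3 with it); it comes from the finalize kernel traits.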
+        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];


         // More than one iter iff ROWS_PER_CTA < 32.
@@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)

             memset(&dbeta_local, 0, sizeof(dbeta_local));
             memset(&dgamma_local, 0, sizeof(dgamma_local));
+            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }

             // Load beta and gamma transposed
             if(read_row < Kernel_traits::ROWS_PER_CTA){
                 dbeta_local.load_from(smem_beta, read_idx);
                 dgamma_local.load_from(smem_gamma, read_idx);
+                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
             }

             // Call reducer on the loaded value(s) and convert.
@@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)

                 dgamma_local.data.elt[it] = g_i;
                 dbeta_local.data.elt[it] = b_i;
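+                // dcolscale is allreduced across the warp exactly like dgamma and dbeta above.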
+                if (Has_colscale) {
+                    compute_t cs_i = dcolscale_local.data.elt[it];
+                    cs_i = reducer.allreduce(cs_i, sum);
+                    dcolscale_local.data.elt[it] = cs_i;
+                }
             }

             // Leader stores the result at the current column.
             if(lane == 0){
                 dgamma_local.store_to(smem_gamma_out, w);
                 dbeta_local.store_to(smem_beta_out, w);
+                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
             }

         }
@@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)

             using src_t = typename TypeToVec2<compute_t>::Type;
             using dst_t = typename TypeToVec2<weight_t>::Type;
-            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
-            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
+            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;

             dgamma_vec2.load_from(smem_gamma_out, lane);
             dbeta_vec2.load_from(smem_beta_out, lane);
+            if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
             #pragma unroll
             for( int it = 0; it < NUM_ELT; it++ ) {
                 dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
                 dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+                if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
             }
             dgamma_out2.store_to(params.dgamma, col_out);
             dbeta_out2.store_to(params.dbeta, col_out);
-
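+            // The reduced dcolscale is converted and written out just like dgamma/dbeta.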
+            if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
         }
     }
 }
@@ -364,7 +420,7 @@ template<
     int BYTES_PER_LDG_MAIN,
     int BYTES_PER_LDG_FINAL
 >
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){

     using Kernel_traits = Kernel_traits<weight_t,
                                         input_t,
@@ -378,59 +434,64 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                         WARPS_N,
                                         BYTES_PER_LDG_MAIN
                                         >;
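+    // prenorm is now inferred from the params (a dx buffer was provided) rather than passed as an argument.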
+    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
     bool has_residual = launch_params.params.dx1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(prenorm, PrenormConst, [&] {
         BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
             BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
-                    if( configure_params ) {
-                        int ctas_per_sm;
-                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
-                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                        launch_params.barrier_size = 0;
-                        launch_params.workspace_bytes = 0;
-                        if(Kernel_traits::CTAS_PER_ROW > 1) {
-                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                * Kernel_traits::WARPS_M
-                                * Kernel_traits::CTAS_PER_ROW
-                                * sizeof(typename Kernel_traits::reduce_t)
-                                * 2;
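+                // Extra compile-time dispatch on has_colscale; the body below is the old code with one more template argument and an extra indent level.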
+                BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+                    BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+                        if( configure_params ) {
+                            int ctas_per_sm;
+                            CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                                &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
+                            launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                            launch_params.barrier_size = 0;
+                            launch_params.workspace_bytes = 0;
+                            if(Kernel_traits::CTAS_PER_ROW > 1) {
+                                launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                                launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                    * Kernel_traits::WARPS_M
+                                    * Kernel_traits::CTAS_PER_ROW
+                                    * sizeof(typename Kernel_traits::reduce_t)
+                                    * 2;
+                            }
+                            return;
                         }
-                        return;
-                    }
-
-                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
-                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
-                    }
-                    auto stream = launch_params.stream;
-                    auto ctas_per_col = launch_params.params.ctas_per_col;
-
-                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
-                    } else {
-                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                        dim3 block(Kernel_traits::THREADS_PER_CTA);
-                        void *params_ = (void *)&launch_params.params;
-                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
-                    }
+                        if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                        }
+                        auto stream = launch_params.stream;
+                        auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                        if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                            kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                        } else {
+                            dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                            dim3 block(Kernel_traits::THREADS_PER_CTA);
+                            void *params_ = (void *)&launch_params.params;
+                            cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                        }

-                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
-                                                                               weight_t,
-                                                                               input_t,
-                                                                               residual_t,
-                                                                               output_t,
-                                                                               compute_t,
-                                                                               index_t,
-                                                                               32 * 32,  // THREADS_PER_CTA
-                                                                               BYTES_PER_LDG_FINAL>;
-
-                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
-                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                        using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
+                                                                                   weight_t,
+                                                                                   input_t,
+                                                                                   residual_t,
+                                                                                   output_t,
+                                                                                   compute_t,
+                                                                                   index_t,
+                                                                                   HasColscaleConst,
+                                                                                   32 * 32,  // THREADS_PER_CTA
+                                                                                   BYTES_PER_LDG_FINAL>;
+
+                        auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
+                        kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                    });
                 });
             });
         });