Browse Source

[LayerNorm] Check cuda error after querying ctas_per_sm

Tri Dao 2 năm trước cách đây
mục cha
commit
e4d3013e15

+ 2 - 2
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu

@@ -44,8 +44,8 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                     auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
                     if( configure_params ) {
                         int ctas_per_sm;
-                        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
+                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                         launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                         launch_params.barrier_size = 0;
                         launch_params.workspace_bytes = 0;

+ 2 - 2
csrc/layer_norm/ln_fwd_cuda_kernel.cu

@@ -41,8 +41,8 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                 auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
                 if( configure_params ) {
                     int ctas_per_sm;
-                    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
+                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                     launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                     const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                     launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;