@@ -4,8 +4,6 @@

 #pragma once

-#include <ATen/cuda/CUDAContext.h>
-
 #include "cute/tensor.hpp"

 #include "cutlass/cluster_launch.hpp"
@@ -15,6 +13,7 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 #include "kernel_traits.h"
+#include "utils.h"

 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
     // If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
     // flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();

     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
     // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
     if (smem_size >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }

     static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     }
     // cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
     //     tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();

     auto tma_load_dQaccum = make_tma_copy(
         typename cute::SM90_TMA_LOAD{},
@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
     auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
     if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
+        CHECK_CUDA(cudaFuncSetAttribute(
             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
     }
     kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();
     // auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
     // if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
-    //     C10_CUDA_CHECK(cudaFuncSetAttribute(
+    //     CHECK_CUDA(cudaFuncSetAttribute(
     //         kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
     // }
     // int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
     // dim3 grid_n(num_n_block, params.b, params.h);
     // kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
-    // C10_CUDA_KERNEL_LAUNCH_CHECK();
+    // CHECK_CUDA_KERNEL_LAUNCH();
 }
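
The C10_CUDA_* macros replaced above come from PyTorch's c10 library; after this change the launcher relies on CHECK_CUDA and CHECK_CUDA_KERNEL_LAUNCH, which are expected to be provided by the newly included "utils.h". A minimal, hypothetical sketch of torch-free definitions for these two macros (the actual utils.h may differ) could look like this:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical sketch: evaluate a CUDA runtime call and abort with a readable
// message if it did not return cudaSuccess.
#define CHECK_CUDA(call)                                                    \
    do {                                                                    \
        cudaError_t status_ = call;                                         \
        if (status_ != cudaSuccess) {                                       \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(status_));                           \
            exit(1);                                                        \
        }                                                                   \
    } while (0)

// Hypothetical sketch: check the error state left by the most recent kernel
// launch (a <<<...>>> call), mirroring C10_CUDA_KERNEL_LAUNCH_CHECK().
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

Keeping the launch check as a thin wrapper over cudaGetLastError() preserves the behavior of the existing call sites while removing the dependency on ATen/c10 headers.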