
Remove torchlib dependency from cpp files (#1083)

Cameron Shinn, 8 months ago
parent commit cb516f855b
4 changed files with 25 additions and 26 deletions
  1. hopper/flash.h (+ 0 - 11)
  2. hopper/flash_bwd_launch_template.h (+ 8 - 9)
  3. hopper/flash_fwd_launch_template.h (+ 4 - 5)
  4. hopper/utils.h (+ 13 - 1)

+ 0 - 11
hopper/flash.h

@@ -7,14 +7,6 @@
 #include <cuda.h>
 #include <vector>
 
-#ifdef OLD_GENERATOR_PATH
-#include <ATen/CUDAGeneratorImpl.h>
-#else
-#include <ATen/cuda/CUDAGeneratorImpl.h>
-#endif
-
-#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
-
 #include "cutlass/fast_math.h"  // For cutlass::FastDivmod
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -118,9 +110,6 @@ struct Flash_fwd_params : public Qkv_params {
     // Local window size
     int window_size_left, window_size_right;
 
-    // Random state.
-    at::PhiloxCudaState philox_args;
-
     // Pointer to the RNG seed (idx 0) and offset (idx 1).
     uint64_t * rng_state;
 

+ 8 - 9
hopper/flash_bwd_launch_template.h

@@ -4,8 +4,6 @@
 
 #pragma once
 
-#include <ATen/cuda/CUDAContext.h>
-
 #include "cute/tensor.hpp"
 
 #include "cutlass/cluster_launch.hpp"
@@ -15,6 +13,7 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 #include "kernel_traits.h"
+#include "utils.h"
 
 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
     // If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
     // flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();
 
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
     // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
     if (smem_size >= 48 * 1024) {
-       C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+       CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
     static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     }
     // cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
                                       // tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();
 
     auto tma_load_dQaccum = make_tma_copy(
         typename cute::SM90_TMA_LOAD{},
@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
     auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
     if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024)  {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
+        CHECK_CUDA(cudaFuncSetAttribute(
             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
     }
     kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();
     // auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
     // if (Kernel_traits::kSmemdKVSize >= 48 * 1024)  {
-        // C10_CUDA_CHECK(cudaFuncSetAttribute(
+        // CHECK_CUDA(cudaFuncSetAttribute(
             // kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
     // }
     // int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
     // dim3 grid_n(num_n_block, params.b, params.h);
     // kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
-    // C10_CUDA_KERNEL_LAUNCH_CHECK();
+    // CHECK_CUDA_KERNEL_LAUNCH();
 }
 
 

+ 4 - 5
hopper/flash_fwd_launch_template.h

@@ -4,8 +4,6 @@
 
 #pragma once
 
-#include <ATen/cuda/CUDAContext.h>
-
 #include "cute/tensor.hpp"
 
 #include "cutlass/cutlass.h"
@@ -16,6 +14,7 @@
 #include "tile_scheduler.hpp"
 #include "flash_fwd_kernel.h"
 #include "kernel_traits.h"
+#include "utils.h"
 
 
 template<typename Kernel_traits, bool Is_causal>
@@ -66,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
     // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
     if (smem_size >= 48 * 1024) {
-       C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+       CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
     int device;
@@ -75,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     cudaError status_ = cudaDeviceGetAttribute(
         &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
     if (status_ != cudaSuccess) {
-      C10_CUDA_CHECK(status_);
+      CHECK_CUDA(status_);
     }
     dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
     static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
@@ -83,7 +82,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
     cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
     cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    CHECK_CUDA_KERNEL_LAUNCH();
 }
 
 template<typename T>

+ 13 - 1
hopper/utils.h

@@ -21,6 +21,18 @@
 #include <cutlass/numeric_conversion.h>
 #include <cutlass/numeric_types.h>
 
+#define CHECK_CUDA(call)                                                                                  \
+    do {                                                                                                  \
+        cudaError_t status_ = call;                                                                       \
+        if (status_ != cudaSuccess) {                                                                     \
+            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
+            exit(1);                                                                                      \
+        }                                                                                                 \
+    } while(0)
+
+#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
+
+
 namespace flash {
 
 using namespace cute;
@@ -62,7 +74,7 @@ struct Allreduce {
 
 template<>
 struct Allreduce<2> {
-template<typename T, typename Operator> 
+template<typename T, typename Operator>
 static __device__ __forceinline__ T run(T x, Operator &op) {
     x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
     return x;
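
The CHECK_CUDA / CHECK_CUDA_KERNEL_LAUNCH macros added to hopper/utils.h are what let the launch templates above drop the c10/ATen headers. Below is a minimal, hypothetical usage sketch that is not part of the commit: the macro bodies are copied verbatim from the hunk above so the example builds with plain nvcc, while the kernel and launch shape are made up purely for illustration.

// Standalone sketch (not from the repo): macros copied from hopper/utils.h
// so this compiles without the header's cute/cutlass dependencies.
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

// Illustrative kernel; any launch is checked the same way.
__global__ void fill_kernel(float *out, int n, float value) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { out[i] = value; }
}

int main() {
    const int n = 1 << 20;
    float *d_out = nullptr;

    // Runtime-API calls go through CHECK_CUDA, which prints the error string
    // with file/line and exits on failure -- the role C10_CUDA_CHECK played.
    CHECK_CUDA(cudaMalloc(&d_out, n * sizeof(float)));

    fill_kernel<<<(n + 255) / 256, 256>>>(d_out, n, 1.0f);
    // A kernel launch returns no status; CHECK_CUDA_KERNEL_LAUNCH() inspects
    // cudaGetLastError() right after it, mirroring C10_CUDA_KERNEL_LAUNCH_CHECK.
    CHECK_CUDA_KERNEL_LAUNCH();

    CHECK_CUDA(cudaFree(d_out));
    printf("done\n");
    return 0;
}

One behavioral difference worth noting: C10_CUDA_CHECK reports failures by throwing a c10::Error, whereas the replacement prints to stderr and calls exit(1), so an error terminates the process instead of propagating an exception to the caller.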