Преглед изворног кода

Minor cleanup of softcapping

Tri Dao пре 8 месеци
родитељ
комит
1d536d7de5
2 измењених фајлова са 4 додато и 12 уклоњено
  1. 2 2
      csrc/flash_attn/flash_api.cpp
  2. 2 10
      csrc/flash_attn/src/flash_fwd_kernel.h

+ 2 - 2
csrc/flash_attn/flash_api.cpp

@@ -106,9 +106,9 @@ void set_params_fprop(Flash_fwd_params &params,
     #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
-        params.scale_softmax =  softcap;
+        params.scale_softmax = softcap;
         params.scale_softmax_log2 = softcap * M_LOG2E;
-    }else{
+    } else{
         // Remove potential NaN
         params.softcap = 0.0;
         params.scale_softmax = softmax_scale;

+ 2 - 10
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -24,17 +24,9 @@ using namespace cute;
 
 template <typename Engine, typename Layout>
 __forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    static_assert(Layout::rank == 3, "Only support 3D Tensor");
-    static_assert(decltype(size<0>(tensor))::value == 4, "First dimension must be 4");
     #pragma unroll
-    for (int i=0; i < size<0>(tensor); ++i){  // MMA
-        #pragma unroll
-        for (int mi=0; mi < size<1>(tensor); ++mi){
-            #pragma unroll
-            for (int nj=0; nj < size<2>(tensor); ++nj){
-                tensor(i, mi, nj) = cutlass::fast_tanh(tensor(i, mi, nj) * softcap );
-            }
-        }
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
     }
 }