scaled_masked_softmax.h

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
// The headers below provide c10::Half / c10::BFloat16, TORCH_INTERNAL_ASSERT and
// at::cuda::getCurrentCUDAStream(), all of which are used further down in this file.
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>

namespace {
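// copy_vector: vectorized load/store helper. The 1-element specializations do a
// plain scalar copy; the 4-element specializations move four values per memory
// transaction by reinterpreting the pointers as a wider type (float2 = 8 bytes
// for the 16-bit c10::Half / c10::BFloat16 types, half2 = 4 bytes for uint8_t).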
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}
template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features:
 * 1) input scaling
 * 2) explicit masking
 */
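// Per row of length element_count, the kernel below computes (in acc_t precision):
//   x_j = scale * src_j          if mask_j != 1, else -10000.0
//   y_j = exp(x_j - max_k x_k) / sum_k exp(x_k - max_k x_k)
// and writes y_j to dst, multiplied by 0 when the entire row is masked so that a
// fully-masked row yields zeros instead of NaNs. Each warp handles WARP_BATCH rows.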
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // compute scale value to account for full mask
    acc_t scale_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
    }

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}
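// The backward kernel below computes, per row, the usual softmax gradient
//   dx_j = scale * y_j * (dy_j - sum_k dy_k * y_k)
// where y is the forward softmax output and dy the incoming gradient. Masked
// positions need no special handling because their y_j is (numerically) zero.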
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace
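
// Host-side helper: reports how many (batch, head, query) rows a single thread
// block processes, so the caller can size its grid. It mirrors the warp_size /
// batches_per_warp logic used by dispatch_scaled_masked_softmax_forward below;
// only key_seq_len actually affects the result.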
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 12: // 4096
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 13: // 8192
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
        }
    }
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 12: // 4096
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 13: // 8192
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}
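
// Usage sketch (illustrative only, not part of this header): a host-side wrapper
// compiled as a CUDA extension would typically take contiguous CUDA tensors,
// extract raw pointers, and call the dispatcher on the current stream roughly as
// below. Tensor names, shapes, and the scale_factor variable are hypothetical.
//
//   auto out = torch::empty_like(input);        // input: [batches, attn_heads, q_len, k_len]
//   dispatch_scaled_masked_softmax_forward<c10::Half, c10::Half, float>(
//       out.data_ptr<c10::Half>(),
//       input.data_ptr<c10::Half>(),
//       mask.data_ptr<uint8_t>(),               // 1 = masked position, 0 = keep
//       scale_factor,
//       /*query_seq_len=*/input.size(2),
//       /*key_seq_len=*/input.size(3),
//       /*batches=*/input.size(0),
//       /*attn_heads=*/input.size(1),
//       /*pad_batches=*/1);                     // 1 selects the "gpt2 style" mask indexing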