flash.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cuda.h>
  6. #include <vector>
  7. #include "cutlass/fast_math.h" // For cutlass::FastDivmod
  8. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters shared by every attention pass: the Q/K/V base pointers, their
  // strides, and the head counts.
  struct Qkv_params {
      using index_t = int64_t;

      // Base pointers of the Q, K and V tensors.
      void *__restrict__ q_ptr;
      void *__restrict__ k_ptr;
      void *__restrict__ v_ptr;

      // Strides of Q, K and V, grouped per dimension (batch, then row, then head).
      // Declaration order is unchanged, so the struct layout is unchanged.
      index_t q_batch_stride, k_batch_stride, v_batch_stride;
      index_t q_row_stride, k_row_stride, v_row_stride;
      index_t q_head_stride, k_head_stride, v_head_stride;

      // Number of query heads (h) and key/value heads (h_k). With multi-query or
      // grouped-query attention (MQA/GQA), h_k may differ from h.
      int h, h_k;
      // Precomputed h / h_k.
      int h_h_k_ratio;
  };
  31. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Forward-pass parameters. Extends Qkv_params with output buffers, sequence
  // metadata, KV-cache / rotary / dropout / windowing options, and flags.
  // NOTE: field order and types are part of the layout shared with compiled
  // kernels; only comments may be changed here without rebuilding everything.
  struct Flash_fwd_params : public Qkv_params {
      // The O matrix (output).
      void * __restrict__ o_ptr;
      // Partial-output accumulator; presumably used by the split-KV path
      // (see num_splits / oaccum_split_stride) — TODO confirm.
      void * __restrict__ oaccum_ptr;
      // The batch / row / head strides of O.
      index_t o_batch_stride;
      index_t o_row_stride;
      index_t o_head_stride;
      // The batch / row / head / split strides of Oaccum.
      index_t oaccum_batch_stride;
      index_t oaccum_row_stride;
      index_t oaccum_head_stride;
      index_t oaccum_split_stride;
      // The pointer to the P matrix.
      void * __restrict__ p_ptr;
      // The softmax log-sum-exp (LSE) output and its split-KV accumulator.
      void * __restrict__ softmax_lse_ptr;
      void * __restrict__ softmax_lseaccum_ptr;
      // The dimensions.
      int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k;
      int b_k;
      // The scaling factors for the kernel.
      float scale_softmax;
      float scale_softmax_log2;
      uint32_t scale_softmax_log2_half2;
      // Arrays of length b+1 holding the starting offset of each sequence.
      int * __restrict__ cu_seqlens_q;
      int * __restrict__ cu_seqlens_k;
      // If provided, the actual length of each q / o sequence.
      int * __restrict__ seqused_q;
      // If provided, the actual length of each k / v sequence.
      int * __restrict__ seqused_k;
      int *__restrict__ blockmask;
      // The K_new and V_new matrices.
      void * __restrict__ knew_ptr;
      void * __restrict__ vnew_ptr;
      // The batch / row / head strides of the K_new and V_new matrices.
      index_t knew_batch_stride;
      index_t vnew_batch_stride;
      index_t knew_row_stride;
      index_t vnew_row_stride;
      index_t knew_head_stride;
      index_t vnew_head_stride;
      // The cos and sin matrices for rotary embedding.
      void * __restrict__ rotary_cos_ptr;
      void * __restrict__ rotary_sin_ptr;
      // The indices to index into the KV cache.
      int * __restrict__ cache_batch_idx;
      // Paged KV cache.
      int * __restrict__ block_table;
      index_t block_table_batch_stride;
      int page_block_size;
      // The dropout KEEP probability (probability of keeping an activation),
      // i.e. 1 - drop rate.
      float p_dropout;
      // uint32_t p_dropout_in_uint;
      // uint16_t p_dropout_in_uint16_t;
      uint8_t p_dropout_in_uint8_t;
      // Rescale factor for kept activations: 1 / p_dropout, since p_dropout
      // above already holds the keep probability (equivalently
      // 1 / (1 - drop rate)).
      float rp_dropout;
      float scale_softmax_rp_dropout;
      // Local window size.
      int window_size_left, window_size_right;
      // Pointer to the RNG seed (idx 0) and offset (idx 1).
      uint64_t * rng_state;
      bool is_bf16;
      bool is_e4m3;
      bool is_causal;
      bool is_local;
      bool is_kv_cache;
      bool use_gqa_packing;
      bool is_rotary_interleaved;
      int num_splits; // For split-KV version
      void * __restrict__ alibi_slopes_ptr;
      index_t alibi_slopes_batch_stride;
      bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
      bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
      int * __restrict__ tile_count_semaphore;
      // Per-tensor descale factors; presumably for the FP8 (is_e4m3) path — TODO confirm.
      float * __restrict__ descale_q_ptr;
      float * __restrict__ descale_k_ptr;
      float * __restrict__ descale_v_ptr;
  };
  113. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Backward-pass parameters. Extends Flash_fwd_params with gradient buffers,
  // their strides, and backward-only scratch state.
  // NOTE: field order and types are part of the layout shared with compiled
  // kernels; only comments may be changed here without rebuilding everything.
  struct Flash_bwd_params : public Flash_fwd_params {
      // The dO, dQ, dK and dV matrices.
      void *__restrict__ do_ptr;
      void *__restrict__ dq_ptr;
      void *__restrict__ dk_ptr;
      void *__restrict__ dv_ptr;
      // Accumulation buffers for dQ, dK and dV.
      void *__restrict__ dq_accum_ptr;
      void *__restrict__ dk_accum_ptr;
      void *__restrict__ dv_accum_ptr;
      // The batch / row / head strides of the dO, dQ, dK and dV matrices.
      // These are index_t, i.e. int64_t as declared in Qkv_params, so the
      // stride fields themselves are 64-bit.
      index_t do_batch_stride;
      index_t do_row_stride;
      index_t do_head_stride;
      index_t dq_batch_stride;
      index_t dk_batch_stride;
      index_t dv_batch_stride;
      index_t dq_row_stride;
      index_t dk_row_stride;
      index_t dv_row_stride;
      index_t dq_head_stride;
      index_t dk_head_stride;
      index_t dv_head_stride;
      // The pointer to the softmax d sum.
      void *__restrict__ dsoftmax_sum;
      // Softmax LSE in a log2-based form, presumably for exp2-based kernels — TODO confirm.
      void *__restrict__ softmax_lse_log2_ptr;
      // Semaphore for dQ accumulation; presumably used to order accumulation
      // when `deterministic` is set — TODO confirm.
      int *__restrict__ dq_semaphore;
      bool deterministic;
      index_t dq_accum_split_stride;
  };
  149. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Entry points declared here and defined elsewhere. Each is templated on an
  // element type T and a head dimension, and receives a cudaStream_t, so it
  // presumably enqueues its work on that stream — TODO confirm against the
  // definitions.

  // Forward (fwd) entry point.
  template <typename T, int Headdim>
  void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

  // Forward entry point for the GQA variant; kBlockH is an extra compile-time
  // head-blocking parameter (presumably query heads packed per block — TODO confirm).
  template <typename T, int Headdim, int kBlockH>
  void run_mha_fwd_gqa_(Flash_fwd_params &params, cudaStream_t stream);

  // Backward (bwd) entry point.
  template <typename T, int Headdim>
  void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);