flash.h 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  #include <cstdint>  // int64_t, uint64_t, uint8_t used by the parameter structs
  #include <vector>

  #include <cuda.h>
  #include <cuda_runtime_api.h>  // declares cudaStream_t; <cuda.h> only declares the driver-API CUstream
  7. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Pointers, strides and head counts describing the Q, K and V tensors.
  struct Qkv_params {
    // Signed 64-bit type used for every stride in this struct.
    using index_t = int64_t;

    // Type-erased base pointers of the Q, K and V tensors.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // Strides of Q, K and V along the batch, row and head dimensions
    // (units presumably elements — confirm against the code that fills them).
    index_t q_batch_stride, k_batch_stride, v_batch_stride;
    index_t q_row_stride, k_row_stride, v_row_stride;
    index_t q_head_stride, k_head_stride, v_head_stride;
    // Stride of V along its last (head-dim) axis, judging by the name — confirm.
    index_t v_dim_stride;

    // Number of query heads (h) and key/value heads (h_k).
    int h, h_k;
  };
  28. ////////////////////////////////////////////////////////////////////////////////////////////////////
  29. struct Flash_fwd_params : public Qkv_params {
  30. using index_t = int64_t;
  31. // The O matrix (output).
  32. void * __restrict__ o_ptr;
  33. void * __restrict__ oaccum_ptr;
  34. // The stride between rows of O.
  35. index_t o_batch_stride;
  36. index_t o_row_stride;
  37. index_t o_head_stride;
  38. // The pointer to the softmax sum.
  39. void * __restrict__ softmax_lse_ptr;
  40. void * __restrict__ softmax_lseaccum_ptr;
  41. // For FP8 scaling
  42. float * __restrict__ q_descale_ptr;
  43. float * __restrict__ k_descale_ptr;
  44. float * __restrict__ v_descale_ptr;
  45. index_t q_descale_batch_stride;
  46. index_t q_descale_head_stride;
  47. index_t k_descale_batch_stride;
  48. index_t k_descale_head_stride;
  49. index_t v_descale_batch_stride;
  50. index_t v_descale_head_stride;
  51. // The dimensions.
  52. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
  53. int total_q, total_k, total_knew;
  54. int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
  55. // The scaling factors for the kernel.
  56. float scale_softmax;
  57. float softcap;
  58. // array of length b+1 holding starting offset of each sequence.
  59. int * __restrict__ cu_seqlens_q;
  60. int * __restrict__ cu_seqlens_k;
  61. int * __restrict__ cu_seqlens_knew;
  62. int * __restrict__ leftpad_k;
  63. // If provided, the actual length of each q/k sequence.
  64. int *__restrict__ seqused_q;
  65. int *__restrict__ seqused_k;
  66. // The stride between rows of Oaccum.
  67. index_t oaccum_split_stride;
  68. index_t oaccum_batch_stride;
  69. index_t oaccum_row_stride;
  70. index_t oaccum_head_stride;
  71. // The stride between rows of LSEaccum.
  72. index_t lseaccum_split_stride;
  73. index_t lseaccum_batch_stride;
  74. index_t lseaccum_head_stride;
  75. // The K_new and V_new matrices.
  76. void * __restrict__ knew_ptr;
  77. void * __restrict__ vnew_ptr;
  78. // The stride between rows of the Q, K and V matrices.
  79. index_t knew_batch_stride;
  80. index_t vnew_batch_stride;
  81. index_t knew_row_stride;
  82. index_t vnew_row_stride;
  83. index_t knew_head_stride;
  84. index_t vnew_head_stride;
  85. // The cos and sin matrices for rotary embedding.
  86. void * __restrict__ rotary_cos_ptr;
  87. void * __restrict__ rotary_sin_ptr;
  88. // The indices to index into the KV cache.
  89. int * __restrict__ kv_batch_idx;
  90. // Paged KV cache
  91. int * __restrict__ page_table;
  92. index_t page_table_batch_stride;
  93. int page_size;
  94. int num_pages;
  95. // The dropout probability (probability of keeping an activation).
  96. float p_dropout;
  97. // uint32_t p_dropout_in_uint;
  98. // uint16_t p_dropout_in_uint16_t;
  99. uint8_t p_dropout_in_uint8_t;
  100. // Scale factor of 1 / (1 - p_dropout).
  101. float rp_dropout;
  102. // Local window size
  103. int window_size_left, window_size_right;
  104. int sink_token_length;
  105. // Pointer to the RNG seed (idx 0) and offset (idx 1).
  106. uint64_t * rng_state;
  107. bool is_bf16;
  108. bool is_fp32;
  109. bool is_e4m3;
  110. bool is_causal;
  111. bool is_local;
  112. bool is_rotary_interleaved;
  113. int num_splits; // For split-KV version
  114. bool pack_gqa;
  115. int * __restrict__ tile_count_semaphore;
  116. };
  117. ////////////////////////////////////////////////////////////////////////////////////////////////////
  118. struct Flash_bwd_params : public Flash_fwd_params {
  119. using index_t = int64_t;
  120. // The dO and dQKV matrices.
  121. void *__restrict__ do_ptr;
  122. void *__restrict__ dq_ptr;
  123. void *__restrict__ dk_ptr;
  124. void *__restrict__ dv_ptr;
  125. // To accumulate dQ
  126. void *__restrict__ dq_accum_ptr;
  127. void *__restrict__ dk_accum_ptr;
  128. void *__restrict__ dv_accum_ptr;
  129. // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
  130. // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
  131. // dv_accum_ptr;
  132. // The stride between rows of the dO, dQ, dK and dV matrices.
  133. index_t do_batch_stride;
  134. index_t do_row_stride;
  135. index_t do_head_stride;
  136. index_t dq_batch_stride;
  137. index_t dk_batch_stride;
  138. index_t dv_batch_stride;
  139. index_t dq_row_stride;
  140. index_t dk_row_stride;
  141. index_t dv_row_stride;
  142. index_t dq_head_stride;
  143. index_t dk_head_stride;
  144. index_t dv_head_stride;
  145. // The pointer to the softmax d sum.
  146. void *__restrict__ dsoftmax_sum;
  147. void *__restrict__ softmax_lse_log2_ptr;
  148. int *__restrict__ dq_semaphore;
  149. int *__restrict__ dk_semaphore;
  150. int *__restrict__ dv_semaphore;
  151. bool deterministic;
  152. index_t dq_accum_split_stride;
  153. };
  154. ////////////////////////////////////////////////////////////////////////////////////////////////////
  155. template<typename T, int Headdim, bool Split, bool PagedKV> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
  156. template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
  157. template<typename T, typename Tpartial, int Headdim> void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream);