// flash.h
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#pragma once

#include <cuda.h>

#include <vector>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh>  // For at::cuda::philox::unpack
  13. constexpr int TOTAL_DIM = 0;
  14. constexpr int H_DIM = 1;
  15. constexpr int D_DIM = 2;
  16. ////////////////////////////////////////////////////////////////////////////////////////////////////
  17. struct Qkv_params {
  18. using index_t = uint32_t;
  19. // The QKV matrices.
  20. void *__restrict__ q_ptr;
  21. void *__restrict__ k_ptr;
  22. void *__restrict__ v_ptr;
  23. // The stride between rows of the Q, K and V matrices.
  24. index_t q_batch_stride;
  25. index_t k_batch_stride;
  26. index_t v_batch_stride;
  27. index_t q_row_stride;
  28. index_t k_row_stride;
  29. index_t v_row_stride;
  30. index_t q_head_stride;
  31. index_t k_head_stride;
  32. index_t v_head_stride;
  33. // The number of heads.
  34. int h, h_k;
  35. // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
  36. // different from nheads (query).
  37. int h_h_k_ratio; // precompute h / h_k,
  38. };
  39. ////////////////////////////////////////////////////////////////////////////////////////////////////
  40. struct Flash_fwd_params : public Qkv_params {
  41. // The O matrix (output).
  42. void * __restrict__ o_ptr;
  43. void * __restrict__ oaccum_ptr;
  44. // The stride between rows of O.
  45. index_t o_batch_stride;
  46. index_t o_row_stride;
  47. index_t o_head_stride;
  48. // The pointer to the P matrix.
  49. void * __restrict__ p_ptr;
  50. // The pointer to the softmax sum.
  51. void * __restrict__ softmax_lse_ptr;
  52. void * __restrict__ softmax_lseaccum_ptr;
  53. // The dimensions.
  54. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
  55. // The scaling factors for the kernel.
  56. float scale_softmax;
  57. float scale_softmax_log2;
  58. // array of length b+1 holding starting offset of each sequence.
  59. int * __restrict__ cu_seqlens_q;
  60. int * __restrict__ cu_seqlens_k;
  61. int *__restrict__ blockmask;
  62. // The K_new and V_new matrices.
  63. void * __restrict__ knew_ptr;
  64. void * __restrict__ vnew_ptr;
  65. // The stride between rows of the Q, K and V matrices.
  66. index_t knew_batch_stride;
  67. index_t vnew_batch_stride;
  68. index_t knew_row_stride;
  69. index_t vnew_row_stride;
  70. index_t knew_head_stride;
  71. index_t vnew_head_stride;
  72. // The cos and sin matrices for rotary embedding.
  73. void * __restrict__ rotary_cos_ptr;
  74. void * __restrict__ rotary_sin_ptr;
  75. // The dropout probability (probability of keeping an activation).
  76. float p_dropout;
  77. // uint32_t p_dropout_in_uint;
  78. // uint16_t p_dropout_in_uint16_t;
  79. uint8_t p_dropout_in_uint8_t;
  80. // Scale factor of 1 / (1 - p_dropout).
  81. float rp_dropout;
  82. float scale_softmax_rp_dropout;
  83. // Random state.
  84. at::PhiloxCudaState philox_args;
  85. // Pointer to the RNG seed (idx 0) and offset (idx 1).
  86. uint64_t * rng_state;
  87. bool is_bf16;
  88. bool is_causal;
  89. // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
  90. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
  91. bool is_seqlens_k_cumulative;
  92. bool is_rotary_interleaved;
  93. int num_splits; // For split-KV version
  94. };
  95. ////////////////////////////////////////////////////////////////////////////////////////////////////
  96. struct Flash_bwd_params : public Flash_fwd_params {
  97. // The dO and dQKV matrices.
  98. void *__restrict__ do_ptr;
  99. void *__restrict__ dq_ptr;
  100. void *__restrict__ dk_ptr;
  101. void *__restrict__ dv_ptr;
  102. // To accumulate dQ
  103. void *__restrict__ dq_accum_ptr;
  104. void *__restrict__ dk_accum_ptr;
  105. void *__restrict__ dv_accum_ptr;
  106. // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
  107. // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
  108. // dv_accum_ptr;
  109. // The stride between rows of the dO, dQ, dK and dV matrices.
  110. // TD [2022-04-16]: We're using 32-bit indexing to save registers.
  111. // The code probably won't work for arrays larger than 2GB.
  112. index_t do_batch_stride;
  113. index_t do_row_stride;
  114. index_t do_head_stride;
  115. index_t dq_batch_stride;
  116. index_t dk_batch_stride;
  117. index_t dv_batch_stride;
  118. index_t dq_row_stride;
  119. index_t dk_row_stride;
  120. index_t dv_row_stride;
  121. index_t dq_head_stride;
  122. index_t dk_head_stride;
  123. index_t dv_head_stride;
  124. // The pointer to the softmax d sum.
  125. void *__restrict__ dsoftmax_sum;
  126. };
  127. ////////////////////////////////////////////////////////////////////////////////////////////////////
  128. template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
  129. template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
  130. template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);