flash.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cuda.h>
  6. #include <vector>
  7. #ifdef OLD_GENERATOR_PATH
  8. #include <ATen/CUDAGeneratorImpl.h>
  9. #else
  10. #include <ATen/cuda/CUDAGeneratorImpl.h>
  11. #endif
  12. #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
  // Small named integer constants with values 0, 1 and 2.
  // NOTE(review): presumably axis indices (total tokens, heads, head dim) of
  // a tensor shape -- they are not referenced in this part of the header, so
  // confirm against the call sites.
  constexpr int TOTAL_DIM{0};
  constexpr int H_DIM{1};
  constexpr int D_DIM{2};
  16. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Plain parameter bundle describing the Q, K and V inputs: one untyped base
  // pointer per matrix, the batch/row/head strides of each, and the head
  // counts. The struct has no constructors or methods; every field is filled
  // in by the caller.
  struct Qkv_params {
      // 64-bit signed integer type used for all stride fields.
      typedef int64_t index_t;

      // Base pointers of the Q, K and V matrices. The element type is erased
      // here (void*).
      void* __restrict__ q_ptr;
      void* __restrict__ k_ptr;
      void* __restrict__ v_ptr;

      // Strides of Q, K and V, grouped by level: batch, then row, then head.
      // NOTE(review): the unit (elements vs. bytes) is not stated in this
      // header -- confirm against the code that fills these in.
      index_t q_batch_stride;
      index_t k_batch_stride;
      index_t v_batch_stride;

      index_t q_row_stride;
      index_t k_row_stride;
      index_t v_row_stride;

      index_t q_head_stride;
      index_t k_head_stride;
      index_t v_head_stride;

      // Number of query heads.
      int h;
      // Number of key/value heads. With multi-query or grouped-query
      // attention (MQA/GQA) this can differ from the number of query heads.
      int h_k;
      // Precomputed h / h_k.
      int h_h_k_ratio;
  };
  39. ////////////////////////////////////////////////////////////////////////////////////////////////////
  40. struct Flash_fwd_params : public Qkv_params {
  41. // The O matrix (output).
  42. void* __restrict__ o_ptr;
  43. void* __restrict__ oaccum_ptr;
  44. // The stride between rows of O.
  45. index_t o_batch_stride;
  46. index_t o_row_stride;
  47. index_t o_head_stride;
  48. // The pointer to the P matrix.
  49. void* __restrict__ p_ptr;
  50. // The pointer to the softmax sum.
  51. void* __restrict__ softmax_lse_ptr;
  52. void* __restrict__ softmax_lseaccum_ptr;
  53. // The dimensions.
  54. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded,
  55. d_rounded, rotary_dim, total_q;
  56. // The scaling factors for the kernel.
  57. float scale_softmax;
  58. float scale_softmax_log2;
  59. // array of length b+1 holding starting offset of each sequence.
  60. int* __restrict__ cu_seqlens_q;
  61. int* __restrict__ cu_seqlens_k;
  62. // If provided, the actual length of each k sequence.
  63. int* __restrict__ seqused_k;
  64. int* __restrict__ blockmask;
  65. // The K_new and V_new matrices.
  66. void* __restrict__ knew_ptr;
  67. void* __restrict__ vnew_ptr;
  68. // The stride between rows of the Q, K and V matrices.
  69. index_t knew_batch_stride;
  70. index_t vnew_batch_stride;
  71. index_t knew_row_stride;
  72. index_t vnew_row_stride;
  73. index_t knew_head_stride;
  74. index_t vnew_head_stride;
  75. // The cos and sin matrices for rotary embedding.
  76. void* __restrict__ rotary_cos_ptr;
  77. void* __restrict__ rotary_sin_ptr;
  78. // The indices to index into the KV cache.
  79. int* __restrict__ cache_batch_idx;
  80. // Paged KV cache
  81. int* __restrict__ block_table;
  82. index_t block_table_batch_stride;
  83. int page_block_size;
  84. // The dropout probability (probability of keeping an activation).
  85. float p_dropout;
  86. // uint32_t p_dropout_in_uint;
  87. // uint16_t p_dropout_in_uint16_t;
  88. uint8_t p_dropout_in_uint8_t;
  89. // Scale factor of 1 / (1 - p_dropout).
  90. float rp_dropout;
  91. float scale_softmax_rp_dropout;
  92. // Local window size
  93. int window_size_left, window_size_right;
  94. float softcap;
  95. // Random state.
  96. at::PhiloxCudaState philox_args;
  97. // Pointer to the RNG seed (idx 0) and offset (idx 1).
  98. uint64_t* rng_state;
  99. bool is_bf16;
  100. bool is_causal;
  101. // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] -
  102. // cu_seqlens_k[bidb]. Otherwise it's cu_seqlens_k[bidb], i.e., we use
  103. // cu_seqlens_k to store the sequence lengths of K.
  104. bool is_seqlens_k_cumulative;
  105. bool is_rotary_interleaved;
  106. int num_splits; // For split-KV version
  107. void* __restrict__ alibi_slopes_ptr;
  108. index_t alibi_slopes_batch_stride;
  109. bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q]
  110. // format instead of [b, nheads, seqlen_q].
  111. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv
  112. // ngroups), d) to (b, ngroups, nheads_kv, d).
  113. };
  114. ////////////////////////////////////////////////////////////////////////////////////////////////////
  115. template <typename T, int Headdim, bool Is_causal>
  116. void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
  117. template <typename T, int Headdim, bool Is_causal>
  118. void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params,
  119. cudaStream_t stream);