flash_api.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include <torch/extension.h>  // at::Tensor, c10::optional, PYBIND11_MODULE / TORCH_EXTENSION_NAME

#include <vector>

#include "flash_common.hpp"
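
// Fixed-length (padded) forward pass: computes softmax(Q K^T * softmax_scale) V
// over [batch, seqlen, heads, head_dim] tensors, with optional causal masking,
// sliding-window attention (window_size_left/right, -1 = unlimited), ALiBi
// slopes, logit softcapping, and dropout. Returns the attention output and the
// softmax log-sum-exp (plus, depending on the arguments, the dropout-masked
// softmax and RNG state).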
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,                             // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const at::Tensor &k,                       // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const at::Tensor &v,                       // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        c10::optional<at::Tensor> &out_,           // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);
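
// Variable-length forward pass: q/k/v are packed into "total" tokens across the
// batch (no padding) and delimited by the cumulative-sequence-length vectors
// cu_seqlens_q / cu_seqlens_k; max_seqlen_q / max_seqlen_k bound the per-batch
// lengths. An optional block_table enables paged KV storage.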
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                                // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,              // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,               // b+1
               const at::Tensor &cu_seqlens_k,               // b+1
               c10::optional<at::Tensor> &seqused_k,         // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<const at::Tensor> &leftpad_k_,  // batch_size
               c10::optional<at::Tensor> &block_table_,      // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_,     // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);
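
// Backward pass for the fixed-length layout: given dout and the saved forward
// output and softmax_lse, computes the gradients of q, k, and v. dq_/dk_/dv_
// are written in place when provided, otherwise allocated by the kernel.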
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                    // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
        const at::Tensor &q,                       // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                       // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                       // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                     // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,             // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,            // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,            // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,            // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,                     // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state);
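
// Backward pass for the variable-length (packed) layout, mirroring
// mha_varlen_fwd: token ranges are delimited by cu_seqlens_q / cu_seqlens_k.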
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,                    // total_q x num_heads x head_size
               const at::Tensor &q,                       // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,                     // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,             // b x h x s   softmax logsumexp
               c10::optional<at::Tensor> &dq_,            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,            // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,            // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,            // b+1
               const at::Tensor &cu_seqlens_k,            // b+1
               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,                    // max sequence length to choose the kernel
               const float p_dropout,                     // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state);
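
// Decoding-style forward pass against an existing KV cache: optionally appends
// new keys/values (k_, v_) to kcache/vcache, applies rotary embeddings,
// supports paged caches via block_table_, and splits work across num_splits
// partitions (0 lets a heuristic choose).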
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,                                      // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &vcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<const at::Tensor> &k_,                // batch_size x seqlen_knew x num_heads_k x head_size
                c10::optional<const at::Tensor> &v_,                // batch_size x seqlen_knew x num_heads_k x head_size
                c10::optional<const at::Tensor> &seqlens_k_,        // batch_size
                c10::optional<const at::Tensor> &rotary_cos_,       // seqlen_ro x (rotary_dim / 2)
                c10::optional<const at::Tensor> &rotary_sin_,       // seqlen_ro x (rotary_dim / 2)
                c10::optional<const at::Tensor> &cache_batch_idx_,  // indices to index into the KV cache
                c10::optional<const at::Tensor> &leftpad_k_,        // batch_size
                c10::optional<at::Tensor> &block_table_,            // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_,           // num_heads or batch_size x num_heads
                c10::optional<at::Tensor> &out_,                    // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                bool is_causal,
                int window_size_left,
                int window_size_right,
                const float softcap,
                bool is_rotary_interleaved,                         // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
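
// ----------------------------------------------------------------------------
// Illustrative sketch (an assumption, not part of the original file): one way
// to call mha_fwd directly from C++. The function name, sizes, and the use of
// fp16 CUDA tensors are arbitrary choices; head_size is kept a multiple of 8
// as the shape comments above require. The exact contents of the returned
// vector (output, softmax_lse, ...) depend on the arguments and build.
static void example_mha_fwd() {
    const int64_t batch = 2, seqlen = 128, heads = 8, head_dim = 64;
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    at::Tensor q = at::randn({batch, seqlen, heads, head_dim}, opts);
    at::Tensor k = at::randn({batch, seqlen, heads, head_dim}, opts);
    at::Tensor v = at::randn({batch, seqlen, heads, head_dim}, opts);

    c10::optional<at::Tensor> out;           // let the kernel allocate the output
    c10::optional<at::Tensor> alibi_slopes;  // no ALiBi bias

    std::vector<at::Tensor> result = mha_fwd(
        q, k, v, out, alibi_slopes,
        /*p_dropout=*/0.0f,
        /*softmax_scale=*/0.125f,            // 1 / sqrt(head_dim) for head_dim = 64
        /*is_causal=*/true,
        /*window_size_left=*/-1,             // -1 disables the sliding window
        /*window_size_right=*/-1,
        /*softcap=*/0.0f,
        /*return_softmax=*/false,
        /*gen_=*/c10::nullopt);

    // result[0] is the attention output, shaped like q.
    (void)result;
}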