flash_api.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
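
// Forward pass for fixed-length (dense) batches. q/k/v follow the same
// batch_size x seqlen x num_heads x head_size layout documented on mha_bwd
// below; out_ optionally supplies a preallocated output tensor. Returns a
// vector of tensors (the attention output plus auxiliary tensors such as the
// softmax logsumexp consumed by the backward pass).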
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
        const at::Tensor &k,
        const at::Tensor &v,
        c10::optional<at::Tensor> &out_,
        c10::optional<at::Tensor> &alibi_slopes_,
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);
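
// Forward pass for variable-length batches: sequences are packed into
// total_q / total_k rows and delimited by the cumulative-length vectors
// cu_seqlens_q / cu_seqlens_k (each of size b+1). block_table_ enables the
// paged-KV layout described in the k/v shape comments.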
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                                // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,              // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,               // b+1
               const at::Tensor &cu_seqlens_k,               // b+1
               c10::optional<at::Tensor> &seqused_k,         // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<const at::Tensor> &leftpad_k_,  // batch_size
               c10::optional<at::Tensor> &block_table_,      // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_,     // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);
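
// Backward pass for fixed-length batches: given dout and the forward results
// (out, softmax_lse), computes gradients with respect to q, k and v.
// dq_/dk_/dv_ may be preallocated by the caller; rng_state carries the RNG
// state needed to reproduce the forward dropout pattern.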
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state);
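
// Backward pass for variable-length batches; mirrors mha_varlen_fwd's packed
// layout, with cu_seqlens_q / cu_seqlens_k delimiting the per-sequence rows.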
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,                    // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,            // b x h x s, softmax logsumexp
               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,           // b+1
               const at::Tensor &cu_seqlens_k,           // b+1
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,                   // max sequence length to choose the kernel
               const float p_dropout,                    // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state);
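
// Python bindings: the four entry points above are exposed on the extension
// module (whose name is set by TORCH_EXTENSION_NAME at build time) as fwd,
// varlen_fwd, bwd and varlen_bwd.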
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
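
// Illustrative C++ call of the dense forward pass (a sketch, not taken from
// this file). The argument values are assumptions: dropout disabled, the usual
// 1/sqrt(head_size) softmax scale, causal masking, -1 window sizes and 0.f
// softcap assumed to mean "local window / softcap disabled".
//
//   c10::optional<at::Tensor> out_opt, alibi_opt;  // both left empty
//   auto results = mha_fwd(q, k, v, out_opt, alibi_opt,
//                          /*p_dropout=*/0.f,
//                          /*softmax_scale=*/1.f / std::sqrt(float(head_size)),
//                          /*is_causal=*/true,
//                          /*window_size_left=*/-1,
//                          /*window_size_right=*/-1,
//                          /*softcap=*/0.f,
//                          /*return_softmax=*/false,
//                          /*gen_=*/c10::nullopt);
//   at::Tensor out = results[0];  // remaining entries are auxiliary tensors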