ft_attention.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. #include <torch/extension.h>
  2. #include "ATen/cuda/CUDAContext.h"
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include "decoder_masked_multihead_attention.h"
  // Input-validation helpers. Each one expands to a TORCH_CHECK whose failure
  // message starts with the stringified argument expression.

  // Fails unless the device type reported by `x` is CUDA.
  #define CHECK_DEVICE(x) \
      TORCH_CHECK((x).device().type() == torch::kCUDA, #x " must be on CUDA")

  // Fails unless `x.sizes()` equals the list of dimensions given after `x`.
  #define CHECK_SHAPE(x, ...)                                       \
      TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), \
                  #x " must have shape (" #__VA_ARGS__ ")")

  // Fails unless `x.is_contiguous()` returns true.
  #define CHECK_CONTIGUOUS(x) \
      TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
  // Invokes the callable passed as the trailing argument(s) with `scalar_t`
  // aliased to the C++ type that matches the runtime scalar type TYPE:
  //   at::ScalarType::Half     -> at::Half
  //   at::ScalarType::BFloat16 -> at::BFloat16
  //   at::ScalarType::Float    -> float
  // Any other TYPE raises an error whose message contains NAME and the type name.
  //
  // TYPE is evaluated exactly once. The expansion is wrapped in
  // do { ... } while (0) so that `DISPATCH_FLOAT_AND_HALF_AND_BF16(...);` is a
  // single statement and cannot capture a following `else`.
  #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                      \
      do {                                                                       \
          const auto dispatch_scalar_type_ = (TYPE);                             \
          if (dispatch_scalar_type_ == at::ScalarType::Half) {                   \
              using scalar_t = at::Half;                                         \
              __VA_ARGS__();                                                     \
          } else if (dispatch_scalar_type_ == at::ScalarType::BFloat16) {        \
              using scalar_t = at::BFloat16;                                     \
              __VA_ARGS__();                                                     \
          } else if (dispatch_scalar_type_ == at::ScalarType::Float) {           \
              using scalar_t = float;                                            \
              __VA_ARGS__();                                                     \
          } else {                                                               \
              /* TORCH_CHECK(false, ...) replaces the deprecated AT_ERROR. */    \
              TORCH_CHECK(false, #NAME, " not implemented for type '",           \
                          toString(dispatch_scalar_type_), "'");                 \
          }                                                                      \
      } while (0)
  21. template<typename T>
  22. void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
  23. const cudaStream_t& stream);
  24. template<typename T>
  25. void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
  26. const cudaStream_t& stream);
  27. template<typename T>
  28. struct SATypeConverter {
  29. using Type = T;
  30. };
  31. template<>
  32. struct SATypeConverter<at::Half> {
  33. using Type = uint16_t;
  34. };
  35. template<>
  36. struct SATypeConverter<at::BFloat16> {
  37. using Type = __nv_bfloat16;
  38. };
  39. template <typename T>
  40. void set_params(Masked_multihead_attention_params<T> &params,
  41. const size_t batch_size,
  42. const size_t nheads,
  43. const size_t nheads_kv,
  44. const size_t memory_max_seqlen,
  45. const size_t headdim,
  46. const int timestep,
  47. const int rotary_embedding_dim,
  48. const float rotary_base,
  49. const bool neox_rotary_style,
  50. const int q_batch_stride,
  51. const int k_batch_stride,
  52. const int v_batch_stride,
  53. const int nnz_heads,
  54. T *q_ptr,
  55. T *k_ptr,
  56. T *v_ptr,
  57. T *k_cache_ptr,
  58. T *v_cache_ptr,
  59. int *length_per_sample,
  60. T *rotary_cos,
  61. T *rotary_sin,
  62. T *out_ptr,
  63. int *nnz_head_idx) {
  64. // Reset the parameters
  65. memset(&params, 0, sizeof(params));
  66. params.q = q_ptr;
  67. params.k = k_ptr;
  68. params.v = v_ptr;
  69. params.q_bias = nullptr;
  70. params.k_bias = nullptr;
  71. params.v_bias = nullptr;
  72. params.k_cache = k_cache_ptr;
  73. params.v_cache = v_cache_ptr;
  74. params.out = out_ptr;
  75. params.cache_indir = nullptr;
  76. params.stride_q = q_batch_stride;
  77. params.stride_k = k_batch_stride;
  78. params.stride_v = v_batch_stride;
  79. params.batch_size = batch_size;
  80. params.beam_width = 1;
  81. params.memory_max_len = memory_max_seqlen;
  82. params.num_heads = nheads;
  83. params.num_heads_kv = nheads_kv;
  84. params.num_heads_q_kv_ratio = nheads / nheads_kv;
  85. params.nnz_heads = nnz_heads;
  86. params.hidden_size_per_head = headdim;
  87. params.rotary_embedding_dim = rotary_embedding_dim;
  88. params.rotary_base = rotary_base;
  89. params.neox_rotary_style = neox_rotary_style;
  90. params.timestep = timestep;
  91. params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
  92. params.total_padding_tokens = nullptr;
  93. params.masked_tokens = nullptr;
  94. params.prefix_prompt_lengths = nullptr;
  95. params.max_prefix_prompt_length = 0;
  96. params.relative_attention_bias = nullptr;
  97. params.relative_attention_bias_stride = 0;
  98. params.cross_attention_out = nullptr;
  99. params.max_decoder_seq_len = 0;
  100. params.is_return_cross_attentions = false;
  101. params.finished = nullptr;
  102. params.memory_length_per_sample = nullptr;
  103. params.length_per_sample = length_per_sample;
  104. params.rotary_cos = rotary_cos;
  105. params.rotary_sin = rotary_sin;
  106. params.nnz_head_idx = nnz_head_idx;
  107. }
  108. torch::Tensor single_query_attention(const torch::Tensor q,
  109. const torch::Tensor k,
  110. const torch::Tensor v,
  111. torch::Tensor k_cache,
  112. torch::Tensor v_cache,
  113. std::optional<const torch::Tensor> length_per_sample_,
  114. std::optional<const torch::Tensor> rotary_cos_,
  115. std::optional<const torch::Tensor> rotary_sin_,
  116. std::optional<const torch::Tensor> nnz_head_idx_,
  117. const int timestep,
  118. int rotary_embedding_dim = 0,
  119. const float rotary_base = 10000.0f,
  120. const bool neox_rotary_style=true) {
  121. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
  122. int batch_size = v_cache.size(0);
  123. int nheads = q.size(1);
  124. int nheads_kv = v_cache.size(1);
  125. int memory_max_seqlen = v_cache.size(2);
  126. int headdim = v_cache.size(3);
  127. auto input_type = q.scalar_type();
  128. TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
  129. CHECK_SHAPE(q, batch_size, nheads, headdim);
  130. CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
  131. CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
  132. CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
  133. // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
  134. int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
  135. CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
  136. TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
  137. TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
  138. TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
  139. CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
  140. TORCH_CHECK(q.scalar_type() == input_type);
  141. TORCH_CHECK(k.scalar_type() == input_type);
  142. TORCH_CHECK(v.scalar_type() == input_type);
  143. TORCH_CHECK(k_cache.scalar_type() == input_type);
  144. TORCH_CHECK(v_cache.scalar_type() == input_type);
  145. if (length_per_sample_.has_value()) {
  146. auto length_per_sample = length_per_sample_.value();
  147. CHECK_DEVICE(length_per_sample);
  148. CHECK_SHAPE(length_per_sample, batch_size);
  149. CHECK_CONTIGUOUS(length_per_sample);
  150. TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
  151. }
  152. if (rotary_cos_.has_value()) {
  153. auto rotary_cos = rotary_cos_.value();
  154. CHECK_DEVICE(rotary_cos);
  155. rotary_embedding_dim = rotary_cos.size(-1) * 2;
  156. CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
  157. CHECK_CONTIGUOUS(rotary_cos);
  158. TORCH_CHECK(rotary_cos.scalar_type() == input_type);
  159. TORCH_CHECK(rotary_sin_.has_value());
  160. auto rotary_sin = rotary_sin_.value();
  161. CHECK_DEVICE(rotary_sin);
  162. CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);
  163. CHECK_CONTIGUOUS(rotary_sin);
  164. TORCH_CHECK(rotary_sin.scalar_type() == input_type);
  165. }
  166. if (nnz_head_idx_.has_value()) {
  167. auto nnz_head_idx = nnz_head_idx_.value();
  168. CHECK_DEVICE(nnz_head_idx);
  169. int nnz_heads = nnz_head_idx.size(0);
  170. CHECK_SHAPE(nnz_head_idx, nnz_heads);
  171. CHECK_CONTIGUOUS(nnz_head_idx);
  172. TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32);
  173. }
  174. // Otherwise the kernel will be launched from cuda:0 device
  175. at::cuda::CUDAGuard device_guard{q.device()};
  176. torch::Tensor out = torch::empty_like(q);
  177. DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
  178. using DataType = typename SATypeConverter<scalar_t>::Type;
  179. Masked_multihead_attention_params<DataType> params;
  180. set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
  181. rotary_embedding_dim, rotary_base, neox_rotary_style,
  182. q.stride(0), k.stride(0), v.stride(0),
  183. nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
  184. reinterpret_cast<DataType*>(q.data_ptr()),
  185. reinterpret_cast<DataType*>(k.data_ptr()),
  186. reinterpret_cast<DataType*>(v.data_ptr()),
  187. reinterpret_cast<DataType*>(k_cache.data_ptr()),
  188. reinterpret_cast<DataType*>(v_cache.data_ptr()),
  189. length_per_sample_.has_value()
  190. ? length_per_sample_.value().data_ptr<int>() : nullptr,
  191. rotary_cos_.has_value()
  192. ? reinterpret_cast<DataType*>(rotary_cos_.value().data_ptr()) : nullptr,
  193. rotary_sin_.has_value()
  194. ? reinterpret_cast<DataType*>(rotary_sin_.value().data_ptr()) : nullptr,
  195. reinterpret_cast<DataType*>(out.data_ptr()),
  196. nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr<int>() : nullptr
  197. );
  198. auto stream = at::cuda::getCurrentCUDAStream();
  199. masked_multihead_attention(params, stream);
  200. });
  201. return out;
  202. }
  203. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  204. m.def("single_query_attention", &single_query_attention, "Attention with a single query",
  205. py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
  206. py::arg("length_per_sample_"), py::arg("rotary_cos_"),
  207. py::arg("rotary_sin_"), py::arg("nnz_head_idx_"),
  208. py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
  209. py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
  210. }