// ft_attention.cpp
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"

#include "decoder_masked_multihead_attention.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                      \
    if (TYPE == at::ScalarType::Half) {                                        \
        using scalar_t = at::Half;                                             \
        __VA_ARGS__();                                                         \
    } else if (TYPE == at::ScalarType::BFloat16) {                             \
        using scalar_t = at::BFloat16;                                         \
        __VA_ARGS__();                                                         \
    } else if (TYPE == at::ScalarType::Float) {                                \
        using scalar_t = float;                                                \
        __VA_ARGS__();                                                         \
    } else {                                                                   \
        AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'");   \
    }
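// Usage sketch (illustrative only, not compiled here; `tensor` and "my_op" are placeholders):
// the macro brings a `scalar_t` alias into scope for the lambda body, e.g.
//
//   DISPATCH_FLOAT_AND_HALF_AND_BF16(tensor.scalar_type(), "my_op", [&] {
//       auto *ptr = reinterpret_cast<scalar_t*>(tensor.data_ptr());
//       // ... launch a kernel templated on scalar_t ...
//   });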
// Earlier variant of the dispatch macro without BFloat16 support, kept commented out:
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                   \
// if (TYPE == at::ScalarType::Half) {                                         \
//     using scalar_t = at::Half;                                              \
//     __VA_ARGS__();                                                          \
// } else if (TYPE == at::ScalarType::Float) {                                 \
//     using scalar_t = float;                                                 \
//     __VA_ARGS__();                                                          \
// } else {                                                                    \
//     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'");    \
// }
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
                               const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};
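// The kernels declared above appear to operate on raw device types rather than the ATen
// wrappers, so SATypeConverter maps at::Half to uint16_t and at::BFloat16 to __nv_bfloat16,
// while float passes through unchanged. Illustrative check (hypothetical, not part of the
// original file):
//
//   static_assert(std::is_same<SATypeConverter<at::Half>::Type, uint16_t>::value,
//                 "at::Half is reinterpreted as a raw 16-bit value");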
template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
                const size_t batch_size,
                const size_t nheads,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const bool neox_rotary_style,
                T *q_ptr,
                T *k_ptr,
                T *v_ptr,
                T *k_cache_ptr,
                T *v_cache_ptr,
                int *length_per_sample,
                T *out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = 0;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    // params.max_prefix_prompt_length = memory_max_seqlen;  // TODO: what should this be?
    params.max_prefix_prompt_length = 0;  // TODO: what should this be?
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}
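// Note: set_params zero-initializes the whole struct and then fills in only the fields this
// extension uses; as far as this file is concerned, the remaining null pointers and defaults
// correspond to no QKV bias, no beam search (beam_width = 1), no prefix prompt, no relative
// attention bias, and no cross-attention outputs.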
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const bool neox_rotary_style = true) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads, headdim);
    CHECK_SHAPE(v, batch_size, nheads, headdim);
    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    // TODO: avoid the contiguous requirement by storing the stride
    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
    CHECK_CONTIGUOUS(v_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    torch::Tensor out = torch::empty_like(q);
    // Match the macro's (TYPE, NAME, ...) signature: dispatch on q's dtype, report under this op name.
    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                   rotary_embedding_dim, neox_rotary_style,
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value()
                       ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
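// Example call from C++ (a sketch under assumed shapes, not exercised in this file; batch,
// nheads, headdim, max_seqlen and t are placeholders, and the function is normally reached
// through the Python binding below). Shapes follow the checks above: q/k/v are
// [batch, nheads, headdim], v_cache is [batch, nheads, max_seqlen, headdim], and k_cache is
// expected in the interleaved [batch, nheads, headdim/x, max_seqlen, x] layout noted in the
// TODO (x = 8 for fp16).
//
//   auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   auto q = torch::randn({batch, nheads, headdim}, opts);
//   auto k = torch::randn({batch, nheads, headdim}, opts);
//   auto v = torch::randn({batch, nheads, headdim}, opts);
//   auto k_cache = torch::zeros({batch, nheads, headdim / 8, max_seqlen, 8}, opts);
//   auto v_cache = torch::zeros({batch, nheads, max_seqlen, headdim}, opts);
//   auto out = single_query_attention(q, k, v, k_cache, v_cache, c10::nullopt,
//                                     /*timestep=*/t, /*rotary_embedding_dim=*/0,
//                                     /*neox_rotary_style=*/true);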
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim") = 0,
          py::arg("neox_rotary_style") = true);
}