decoder_masked_multihead_attention.h 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. // Downloaded from FasterTransformer v5.2.1
  2. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
  3. /*
  4. * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #pragma once
  19. #include "cuda_bf16_wrapper.h"
  20. #include <cuda_fp16.h>
  21. #include <cuda_runtime_api.h>
  22. #include <stdint.h>
  23. #include <stdio.h>
  24. #include <stdlib.h>
  25. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Evaluates `call` exactly once and stores its result in a cudaError_t.
  // If the result is not cudaSuccess, prints the source file, line and the
  // cudaGetErrorString() text to stderr and terminates the process with exit(1).
  // The argument is parenthesized so that whatever expression is passed expands
  // as a single initializer; the do/while(0) wrapper makes the macro usable as
  // one statement (e.g. in an unbraced if/else).
  #define CHECK_CUDA(call)                                                                   \
      do {                                                                                   \
          cudaError_t status_ = (call);                                                      \
          if (status_ != cudaSuccess) {                                                      \
              fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,                \
                      cudaGetErrorString(status_));                                          \
              exit(1);                                                                       \
          }                                                                                  \
      } while (0)
  34. ////////////////////////////////////////////////////////////////////////////////////////////////////
  35. // The structure of parameters for the masked multihead attention kernel.
  36. //
  37. // We use the following terminology to describe the different dimensions.
  38. //
  39. // B: Batch size (number of sequences),
  40. // L: Sequence length,
  41. // D: Hidden dimension,
  42. // H: Number of heads,
  43. // Dh: Hidden dimension per head - Dh = D / H.
  // Parameter block shared by the multi-head attention parameter structs in
  // this header. T is the element type of the tensor buffers. Every member has
  // a null / zero / false default, so a default-constructed instance is fully
  // initialized. Dimension letters (B, L, D, H, Dh) follow the terminology
  // used elsewhere in this header.
  template<typename T>
  struct Multihead_attention_params_base {
      // Output buffer, B x D.
      T* out{nullptr};

      // Q input (B x D) and its bias (D).
      const T* q{nullptr};
      const T* q_bias{nullptr};
      // K input (B x D) and its bias (D).
      const T* k{nullptr};
      const T* k_bias{nullptr};
      // V input (B x D) and its bias (D).
      const T* v{nullptr};
      const T* v_bias{nullptr};

      // K cache; must hold at least B x L x D elements.
      T* k_cache{nullptr};
      // V cache; must hold at least B x L x D elements.
      T* v_cache{nullptr};
      // Cache indirections used when beam sampling.
      const int* cache_indir{nullptr};

      // Strides for the case where Q, K and V live in one single buffer.
      int stride_q{0};
      int stride_k{0};
      int stride_v{0};

      // Batch size.
      int batch_size{0};
      // Beam width.
      int beam_width{0};
      // Sequence length.
      int memory_max_len{0};
      // Number of heads (H).
      int num_heads{0};
      // NOTE(review): presumably the number of K/V heads - confirm against the kernels.
      int num_heads_kv{0};
      // NOTE(review): presumably num_heads / num_heads_kv - confirm against the callers.
      int num_heads_q_kv_ratio{0};
      // Hidden dimension per head (Dh).
      int hidden_size_per_head{0};

      // Per-head latent space reserved for rotary embeddings.
      int rotary_embedding_dim{0};
      // NOTE(review): presumably selects the NeoX rotary layout - confirm.
      bool neox_rotary_style{false};
      // NOTE(review): presumably the base of the rotary frequencies - confirm.
      float rotary_base{0.0f};

      // Maximum length of the input sentences.
      int max_input_length{0};
      // Current timestep. TODO(bhsueh): check whether this parameter is only needed in cross attention.
      int timestep{0};
      // NOTE(review): a comment about a per-sentence timestep used to sit here,
      // but no such member is declared in this struct.

      // 1.f / sqrt(Dh), computed on the host.
      float inv_sqrt_dh{0.0f};

      // Used when there is some input context (e.g. GPT).
      const int*  total_padding_tokens{nullptr};
      const bool* masked_tokens{nullptr};
      const int*  prefix_prompt_lengths{nullptr};
      int         max_prefix_prompt_length{0};

      // Relative attention bias and its stride.
      const T* relative_attention_bias{nullptr};
      int      relative_attention_bias_stride{0};
      // Per-head slope of the linear position bias added to the attention score (H).
      const T* linear_bias_slopes{nullptr};

      // IA3 key/value weights and task ids (semantics not documented in this header).
      const T*   ia3_key_weights{nullptr};
      const T*   ia3_value_weights{nullptr};
      const int* ia3_tasks{nullptr};

      // Scales and mode for int8 (semantics not documented in this header).
      const float* qkv_scale_out{nullptr};
      const float* attention_out_scale{nullptr};
      int          int8_mode{0};

      // NOTE(review): presumably precomputed rotary cos/sin tables - confirm.
      const T* rotary_cos{nullptr};
      const T* rotary_sin{nullptr};

      // NOTE(review): presumably indices and count of non-zero heads - confirm.
      const int* nnz_head_idx{nullptr};
      int        nnz_heads{0};
  };
  107. template<typename T, bool CROSS_ATTENTION>
  108. struct Multihead_attention_params: public Multihead_attention_params_base<T> {
  109. // output cross attentions
  110. float* cross_attention_out = nullptr;
  111. int max_decoder_seq_len = 0;
  112. bool is_return_cross_attentions = false;
  113. // allows to exist attention eary
  114. bool* finished = nullptr;
  115. // required in case of cross attention
  116. // will need it here till if constexpr in c++17
  117. int* memory_length_per_sample = nullptr;
  118. // required in case of masked attention with different length
  119. const int* length_per_sample = nullptr;
  120. };
  121. template<typename T>
  122. struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
  123. // output cross attentions
  124. float* cross_attention_out = nullptr;
  125. int max_decoder_seq_len = 0;
  126. bool is_return_cross_attentions = false;
  127. // allows to exist attention eary
  128. bool* finished = nullptr;
  129. // required in case of cross attention
  130. int* memory_length_per_sample = nullptr;
  131. // required in case of masked attention with different length
  132. const int* length_per_sample = nullptr;
  133. };
  134. template<class T>
  135. using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
  136. template<class T>
  137. using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
  // Small bundle of the cross-attention output settings. T is the element
  // type of the output buffer. All members default to zero / null / false.
  template<typename T>
  struct outputCrossAttentionParam {
      // Maximum decoder output length.
      int max_decoder_seq_len{0};
      // Output buffer for the cross attentions.
      T* cross_attention_out{nullptr};
      // NOTE(review): presumably enables writing cross_attention_out - confirm against the callers.
      bool is_return_cross_attentions{false};
  };
  145. ////////////////////////////////////////////////////////////////////////////////////////////////////
  146. void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
  147. void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
  148. #ifdef ENABLE_BF16
  149. void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
  150. const cudaStream_t& stream);
  151. #endif
  152. void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
  153. void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
  154. #ifdef ENABLE_BF16
  155. void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
  156. const cudaStream_t& stream);
  157. #endif
  158. ////////////////////////////////////////////////////////////////////////////////////////////////////