decoder_masked_multihead_attention.h

// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while (0)
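
// A minimal usage sketch (not part of the original header): CHECK_CUDA wraps any CUDA runtime
// call that returns cudaError_t, so a failure aborts with file/line context. The helper name and
// buffer size below are illustrative assumptions.
static inline void check_cuda_example(size_t num_elements, cudaStream_t stream)
{
    float* d_buf = nullptr;
    // Allocate and clear a scratch buffer; every runtime call is checked by the macro above.
    CHECK_CUDA(cudaMalloc(reinterpret_cast<void**>(&d_buf), num_elements * sizeof(float)));
    CHECK_CUDA(cudaMemsetAsync(d_buf, 0, num_elements * sizeof(float), stream));
    CHECK_CUDA(cudaStreamSynchronize(stream));
    CHECK_CUDA(cudaFree(d_buf));
}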

////////////////////////////////////////////////////////////////////////////////////////////////////

// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B:  Batch size (number of sequences),
// L:  Sequence length,
// D:  Hidden dimension,
// H:  Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
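//
// For example (illustrative values, not from the original comment): a model with D = 1024 and
// H = 16 heads has Dh = 1024 / 16 = 64 elements per head.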

template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for the cache when beam sampling.
    const int* cache_indir = nullptr;

    // Stride to handle the case when QKV is a single buffer.
    int stride = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width.
    int beam_width = 0;
    // The maximum sequence length (L).
    int memory_max_len = 0;
    // The number of heads (H).
    int num_heads = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;
    // The per-head latent space reserved for rotary embeddings.
    int rotary_embedding_dim = 0;
    bool neox_rotary_style = false;
    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh) Check whether we only need this param in cross attention.
    int timestep = 0;
    // The current timestep of each sentence (supports a different timestep per sentence).

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context, e.g. GPT.
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens = nullptr;
    const int* prefix_prompt_lengths = nullptr;
    int max_prefix_prompt_length = 0;

    const T* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
    // The per-head slope of the linear position bias added to the attention score (H).
    const T* linear_bias_slopes = nullptr;

    const T* ia3_key_weights = nullptr;
    const T* ia3_value_weights = nullptr;
    const int* ia3_tasks = nullptr;

    const float* qkv_scale_out = nullptr;
    const float* attention_out_scale = nullptr;
    int int8_mode = 0;
};
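
// A minimal host-side sketch (an assumption, not part of the original header) of sizing the K/V
// caches declared above: the comments require at least B x L x D elements each, with D = H * Dh
// and L = memory_max_len. The helper name is hypothetical.
template<typename T>
inline void allocate_kv_cache_example(Multihead_attention_params_base<T>& p)
{
    // B x L x D elements, where D = num_heads * hidden_size_per_head.
    const size_t cache_elems = size_t(p.batch_size) * size_t(p.memory_max_len)
                               * size_t(p.num_heads) * size_t(p.hidden_size_per_head);
    CHECK_CUDA(cudaMalloc(reinterpret_cast<void**>(&p.k_cache), cache_elems * sizeof(T)));
    CHECK_CUDA(cudaMalloc(reinterpret_cast<void**>(&p.v_cache), cache_elems * sizeof(T)));
}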

template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // Output cross attentions.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention computation to exit early for finished sequences.
    bool* finished = nullptr;

    // Required in case of cross attention.
    // Will need it here until `if constexpr` (C++17) can be used.
    int* memory_length_per_sample = nullptr;

    // Required in case of masked attention with different lengths.
    const int* length_per_sample = nullptr;
};

template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // Output cross attentions.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention computation to exit early for finished sequences.
    bool* finished = nullptr;

    // Required in case of cross attention.
    int* memory_length_per_sample = nullptr;

    // Required in case of masked attention with different lengths.
    const int* length_per_sample = nullptr;
};

template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // Max decoder output length.
    int max_decoder_seq_len = 0;
    T* cross_attention_out = nullptr;
    bool is_return_cross_attentions = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
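
// A minimal end-to-end sketch (an assumption, not part of the original header) of driving the
// masked self-attention entry point above during decoding. All sizes are illustrative and the
// helper name is hypothetical; the q/k/v, output and cache pointers are assumed to be valid
// device buffers prepared by the caller.
inline void masked_attention_step_example(float*       out,
                                          const float* q,
                                          const float* k,
                                          const float* v,
                                          float*       k_cache,
                                          float*       v_cache,
                                          int          step,
                                          cudaStream_t stream)
{
    Masked_multihead_attention_params<float> params;
    params.out                  = out;
    params.q                    = q;
    params.k                    = k;
    params.v                    = v;
    params.k_cache              = k_cache;  // at least B x L x D elements
    params.v_cache              = v_cache;  // at least B x L x D elements
    params.batch_size           = 8;        // B
    params.beam_width           = 1;        // no beam search in this sketch
    params.memory_max_len       = 1024;     // L
    params.num_heads            = 16;       // H
    params.hidden_size_per_head = 64;       // Dh
    params.stride               = 0;        // q/k/v are separate buffers here, not one fused QKV buffer
    params.inv_sqrt_dh          = 0.125f;   // 1.f / sqrt(Dh) for Dh = 64, computed on the host
    params.timestep             = step;     // current decoding step
    masked_multihead_attention(params, stream);
    CHECK_CUDA(cudaGetLastError());         // surface kernel launch errors early
}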