decoder_masked_multihead_attention.cu
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

#include "decoder_masked_multihead_attention_template.hpp"

////////////////////////////////////////////////////////////////////////////////////////////////////
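// Launch helper: computes the dynamic shared-memory requirement for the selected configuration,
// raises the kernel's dynamic shared-memory limit to match, and launches one thread block per
// (head, batch element) pair. When params.nnz_head_idx is set, only params.nnz_heads heads are
// launched.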
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream)   \
    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK);         \
    auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,                \
                                                          THDS_PER_BLOCK, DO_CROSS_ATTENTION>;                        \
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);                               \
    dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size);               \
    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
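    // Each value-processing thread handles 16 bytes of the head dimension, so Dh_MAX * sizeof(T) / 16
    // threads cooperate on a single value vector. For cross attention the kernel scans the full
    // memory_max_len keys; for self attention it only attends up to the current timestep.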
    constexpr int  THREADS_PER_VALUE  = Dh_MAX * sizeof(T) / 16;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
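    // Block-size heuristic: shorter key sequences use more threads per key with a smaller block,
    // longer ones use fewer threads per key with a larger block (64, 128 or 256 threads).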
    if (tlength < 32) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
    }
    else if (tlength < 2048) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
    }
    else {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#undef MMHA_LAUNCH_KERNEL

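// Dispatch on the per-head hidden size. Dh is the actual head size; Dh_MAX rounds it up to the
// next power of two (at most 256), which the launcher above uses to size THREADS_PER_VALUE.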
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    switch (params.hidden_size_per_head) {
        case 32:
            mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 48:
            mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 64:
            mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 80:
            mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 96:
            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 128:
            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 160:
            mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 192:
            mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 224:
            mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 256:
            mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        default:
            assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

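// Public entry points: instantiate the dispatcher for float, uint16_t (the fp16 storage type) and,
// when ENABLE_BF16 is defined, __nv_bfloat16, for both self (masked) and cross attention.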
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
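// Usage sketch (illustrative only, not compiled here). Only the fields read in this file
// (batch_size, num_heads, hidden_size_per_head, timestep, memory_max_len, nnz_head_idx) are shown;
// the query/key/value, output and KV-cache pointers must also be set, with the exact field names
// defined by the params structs in decoder_masked_multihead_attention.h.
//
//   Masked_multihead_attention_params<uint16_t> params{};
//   params.batch_size           = 8;
//   params.num_heads            = 16;
//   params.hidden_size_per_head = 128;  // dispatches to the Dh == 128 instantiation
//   params.timestep             = 42;   // keys generated so far -> 128-thread blocks
//   params.memory_max_len       = 2048; // KV-cache capacity
//   // ... set the q/k/v, output and cache pointers here ...
//   masked_multihead_attention(params, stream);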