dtype_float32.cuh

/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace aphrodite {
// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};
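
// Note: CUDA's widest built-in float vector type is float4, so Float4_ and
// Float8_ pack 4 and 8 floats as pairs of float2. Keeping the halves as
// float2 lets the add/mul/fma helpers below reuse the same two-lane math
// for every vector width.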
// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};
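
// Vec<float, N>::Type picks the vector type used to load or store N floats,
// and FloatVec<V>::Type is the FP32 type used to accumulate on a value of
// vector type V. For FP32 both mappings are identities; the FP16/BF16
// headers map their packed types onto float vectors such as Float4_ and
// Float8_ instead.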
// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ Float4_ add(Float4_ a, Float4_ b) {
  Float4_ c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}
// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}
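
// The mul overloads above are explicit specializations of the
// mul<Acc, A, B> template declared in attention_generic.cuh; the
// (float, floatN) forms broadcast the scalar across every lane.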
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}
// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
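
// sum() is a horizontal reduction: it collapses every lane of a vector into
// a single float (up to eight lanes for Float8_).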
// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}
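
// Each dot() multiplies element-wise, accumulates per-lane with fma, and
// finishes with a horizontal add; e.g. dot(Float8_, Float8_) is an
// 8-element dot product returning one float.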
// From float to float.
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}
// To float.
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}
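
// For FP32 these conversions are no-ops; they exist so templated attention
// code can call from_float/to_float uniformly regardless of the element dtype.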
// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace aphrodite