dtype_float16.cuh
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <stdint.h>

namespace aphrodite {
// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};

template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};

template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};

template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};

template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};

template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};

template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};
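
// A minimal usage sketch (hypothetical helper, not part of the adapted
// FasterTransformer code): Vec<uint16_t, N> picks the packed storage type
// for N fp16 values, and FloatVec of that storage type picks the matching
// fp32 accumulator, so a kernel can be written generically over N.
inline __device__ void example_vec_traits(const uint16_t* ptr) {
  using Vec8 = typename Vec<uint16_t, 8>::Type;       // uint4: 8 packed halves
  using Acc8 = typename FloatVec<Vec8>::Type;         // Float8_: 8 fp32 lanes
  Vec8 packed = *reinterpret_cast<const Vec8*>(ptr);  // one 16-byte load
  Acc8 acc;                                           // fp32 accumulator for it
  (void)packed; (void)acc;
}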
// Utility functions for type conversions.

// Broadcast one fp16 value into both 16-bit lanes of a 32-bit register.
inline __device__ uint32_t h0_h0(uint16_t a) {
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
  return f;
}

// Unpack a packed fp16x2 register into two fp32 values.
inline __device__ float2 half2_to_float2(uint32_t v) {
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
}

// Convert fp32 to fp16 with round-to-nearest-even.
inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
  return tmp.u16[0];
}

// Pack two fp32 values into one fp16x2 register. Ampere (sm_80) and newer
// have a single packed conversion instruction; older architectures convert
// one lane at a time.
inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
  return tmp.u32;
}
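
// A minimal usage sketch (hypothetical helper): round-tripping fp32 values
// through the packed-fp16 helpers above. Note the fp16 rounding step, so
// the result is only an approximation of the input.
inline __device__ float2 example_roundtrip(float2 f) {
  uint32_t h2 = float2_to_half2(f);         // two fp32 -> one packed fp16x2
  uint32_t bx = h0_h0(float_to_half(f.x));  // f.x broadcast into both lanes
  (void)bx;
  return half2_to_float2(h2);               // widen back to fp32
}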
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Mixed-precision variants: fp16 inputs added to fp32 accumulators.
inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}
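
// A minimal usage sketch (hypothetical helper and layout): the mixed-precision
// add overloads let a reduction read packed fp16 data but keep its running
// total in fp32, which avoids accumulating fp16 rounding error.
inline __device__ Float8_ example_accumulate(const uint4* vals, int n) {
  Float8_ acc;
  acc.x = make_float2(0.f, 0.f);
  acc.y = make_float2(0.f, 0.f);
  acc.z = make_float2(0.f, 0.f);
  acc.w = make_float2(0.f, 0.f);
  for (int i = 0; i < n; ++i) {
    acc = add(vals[i], acc);  // 8 fp16 lanes widened and added in fp32
  }
  return acc;
}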
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

// Scalar-times-vector variants broadcast the fp16 scalar via h0_h0.
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

// Variants that widen the product to fp32 instead of staying in fp16.
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}
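
// A minimal usage sketch (hypothetical helper): applying one fp16 scale
// factor to 8 packed halves at once via the broadcasting specializations.
inline __device__ uint4 example_scale(uint16_t scale, uint4 v) {
  // Explicit <Acc, A, B> arguments select the all-fp16 specialization;
  // internally the scalar is broadcast with h0_h0 and multiplied lanewise.
  return mul<uint4, uint16_t, uint4>(scale, v);
}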
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// Variants that multiply fp16 inputs but accumulate into fp32.
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}
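
// A minimal usage sketch (hypothetical helper and loop bounds): multiplying
// packed fp16 lanes while accumulating in fp32 is the dot-product pattern
// these mixed-precision overloads exist for.
inline __device__ float2 example_dot2(const uint32_t* q, const uint32_t* k, int n) {
  float2 acc = make_float2(0.f, 0.f);
  for (int i = 0; i < n; ++i) {
    acc = fma(q[i], k[i], acc);  // fp16x2 * fp16x2 + float2, computed in fp32
  }
  return acc;  // caller finishes with acc.x + acc.y
}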
// Vector sum. Note that the pairwise reductions below stay in fp16;
// only the final horizontal add is done in fp32.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}
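
// A minimal usage sketch (hypothetical helper): an 8-lane fp16 dot product
// finished with sum(). Because the products and pairwise adds stay in fp16,
// this trades accuracy for speed relative to the fp32-accumulating fma path.
inline __device__ float example_dot8(uint4 a, uint4 b) {
  uint4 prod = mul<uint4, uint4, uint4>(a, b);  // lanewise fp16 multiply
  return sum(prod);                             // reduce 8 lanes to one fp32
}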
// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}
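
// A minimal usage sketch (hypothetical helper): to_float/from_float bracket
// an fp32 compute section, which is how templated kernels move between the
// fp16 storage dtype and the fp32 accumulator domain.
inline __device__ uint4 example_widen_narrow(uint4 packed) {
  Float8_ f = to_float(packed);  // widen 8 fp16 lanes to fp32
  // ... operate on f in fp32 (scaling, normalization, etc.) ...
  uint4 back;
  from_float(back, f);           // narrow back to packed fp16
  return back;
}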
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}

} // namespace aphrodite