dtype_float16.cuh 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. /*
  2. * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  3. * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  4. * Copyright (c) 2023, The PygmalionAI team.
  5. * Copyright (c) 2023, The vLLM team.
  6. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  7. *
  8. * Licensed under the Apache License, Version 2.0 (the "License");
  9. * you may not use this file except in compliance with the License.
  10. * You may obtain a copy of the License at
  11. *
  12. * http://www.apache.org/licenses/LICENSE-2.0
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #pragma once
  21. #include "attention_generic.cuh"
  22. #include "dtype_float32.cuh"
  23. #ifdef USE_ROCM
  24. #include <hip/hip_fp16.h>
  25. #endif
  26. #include <stdint.h>
  27. namespace aphrodite {
  28. // FP16 vector types for Q, K, V.
// Vec<uint16_t, N> maps "N fp16 elements" to the raw register type used to
// carry them: 1 half -> uint16_t, 2 -> uint32_t (packed half2),
// 4 -> uint2 (two half2 words), 8 -> uint4 (four half2 words).
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};
  45. // FP32 accumulator vector types corresponding to Vec.
// FloatVec<T> gives the fp32 accumulator type matching each fp16 vector
// type above: one float per fp16 lane (uint16_t -> float, uint32_t -> float2,
// uint2 -> Float4_, uint4 -> Float8_; Float4_/Float8_ from dtype_float32.cuh).
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};
  62. // Utility functions for type conversions.
// Broadcasts one fp16 value (raw uint16_t bits) into both 16-bit lanes of a
// 32-bit word, producing the packed half2 {a, a}. Used to scale half2 vectors
// by a scalar half.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  // PTX mov.b32 packs the same 16-bit source register into both halves.
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  // ROCm path: build the pair through a union instead of inline asm.
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}
// Converts one fp16 value (raw bits in a uint16_t) to fp32.
inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  // PTX: widen f16 -> f32 (exact; every fp16 value is representable in fp32).
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  // AMD GCN equivalent of the conversion above.
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}
// Unpacks a packed half2 (uint32_t) into an fp32 float2: low 16 bits -> .x,
// high 16 bits -> .y.
inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  // Split the 32-bit word into its two 16-bit halves, then widen each.
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  // ROCm path: split through a union, then reuse the scalar conversion.
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}
// Converts one fp32 value to fp16 (round-to-nearest-even), returning the raw
// 16-bit pattern.
inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  // PTX cvt.rn: round-to-nearest-even f32 -> f16 into the low half.
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  // GCN v_cvt_f16_f32 writes a full 32-bit VGPR; the converted half lands in
  // the low 16 bits, which is what we return below.
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}
// Converts an fp32 float2 into a packed half2 (uint32_t): f.x -> low 16 bits,
// f.y -> high 16 bits, round-to-nearest-even.
inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // SM80+: one instruction converts both lanes. Note the operand order is
  // (hi, lo) — f.y fills the upper 16 bits, f.x the lower.
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  // Older archs: convert each lane separately into the matching half-word.
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
  // ROCm path: reuse the scalar conversion for each lane.
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}
  134. // Vector addition.
// Scalar fp16 addition on raw half bits: c = a + b.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}
// Packed half2 addition: both 16-bit lanes of a and b added in one op.
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}
  153. inline __device__ uint2 add(uint2 a, uint2 b) {
  154. uint2 c;
  155. c.x = add(a.x, b.x);
  156. c.y = add(a.y, b.y);
  157. return c;
  158. }
  159. inline __device__ uint4 add(uint4 a, uint4 b) {
  160. uint4 c;
  161. c.x = add(a.x, b.x);
  162. c.y = add(a.y, b.y);
  163. c.z = add(a.z, b.z);
  164. c.w = add(a.w, b.w);
  165. return c;
  166. }
  167. inline __device__ float2 add(uint32_t a, float2 fb) {
  168. float2 fa = half2_to_float2(a);
  169. return add(fa, fb);
  170. }
  171. inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  172. Float4_ fc;
  173. fc.x = add(a.x, fb.x);
  174. fc.y = add(a.y, fb.y);
  175. return fc;
  176. }
  177. inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  178. Float8_ fc;
  179. fc.x = add(a.x, fb.x);
  180. fc.y = add(a.y, fb.y);
  181. fc.z = add(a.z, fb.z);
  182. fc.w = add(a.w, fb.w);
  183. return fc;
  184. }
  185. // Vector multiplication.
// Scalar fp16 multiply on raw half bits: c = a * b.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}
// Packed half2 multiply: both 16-bit lanes of a and b multiplied in one op.
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}
  206. template<>
  207. inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  208. return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
  209. }
  210. template<>
  211. inline __device__ uint2 mul(uint2 a, uint2 b) {
  212. uint2 c;
  213. c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  214. c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  215. return c;
  216. }
  217. template<>
  218. inline __device__ uint2 mul(uint16_t a, uint2 b) {
  219. uint32_t s = h0_h0(a);
  220. uint2 c;
  221. c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  222. c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  223. return c;
  224. }
  225. template<>
  226. inline __device__ uint4 mul(uint4 a, uint4 b) {
  227. uint4 c;
  228. c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  229. c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  230. c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  231. c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  232. return c;
  233. }
  234. template<>
  235. inline __device__ uint4 mul(uint16_t a, uint4 b) {
  236. uint32_t s = h0_h0(a);
  237. uint4 c;
  238. c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  239. c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  240. c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  241. c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  242. return c;
  243. }
  244. template<>
  245. inline __device__ float mul(uint16_t a, uint16_t b) {
  246. float fa = half_to_float(a);
  247. float fb = half_to_float(b);
  248. return fa * fb;
  249. }
  250. template<>
  251. inline __device__ float2 mul(uint32_t a, uint32_t b) {
  252. float2 fa = half2_to_float2(a);
  253. float2 fb = half2_to_float2(b);
  254. return mul<float2, float2, float2>(fa, fb);
  255. }
  256. template<>
  257. inline __device__ float2 mul(uint16_t a, uint32_t b) {
  258. return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
  259. }
  260. template<>
  261. inline __device__ Float4_ mul(uint2 a, uint2 b) {
  262. Float4_ fc;
  263. fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  264. fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  265. return fc;
  266. }
  267. template<>
  268. inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  269. uint32_t s = h0_h0(a);
  270. Float4_ fc;
  271. fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  272. fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  273. return fc;
  274. }
  275. template<>
  276. inline __device__ Float8_ mul(uint4 a, uint4 b) {
  277. Float8_ fc;
  278. fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  279. fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  280. fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  281. fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  282. return fc;
  283. }
  284. template<>
  285. inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  286. uint32_t s = h0_h0(a);
  287. Float8_ fc;
  288. fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  289. fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  290. fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  291. fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  292. return fc;
  293. }
  294. // Vector fused multiply-add.
// Packed half2 fused multiply-add: d = a * b + c, per 16-bit lane,
// round-to-nearest-even.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}
  304. inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  305. return fma(h0_h0(a), b, c);
  306. }
  307. inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  308. uint2 d;
  309. d.x = fma(a.x, b.x, c.x);
  310. d.y = fma(a.y, b.y, c.y);
  311. return d;
  312. }
  313. inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  314. uint32_t s = h0_h0(a);
  315. uint2 d;
  316. d.x = fma(s, b.x, c.x);
  317. d.y = fma(s, b.y, c.y);
  318. return d;
  319. }
  320. inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  321. uint4 d;
  322. d.x = fma(a.x, b.x, c.x);
  323. d.y = fma(a.y, b.y, c.y);
  324. d.z = fma(a.z, b.z, c.z);
  325. d.w = fma(a.w, b.w, c.w);
  326. return d;
  327. }
  328. inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  329. uint32_t s = h0_h0(a);
  330. uint4 d;
  331. d.x = fma(s, b.x, c.x);
  332. d.y = fma(s, b.y, c.y);
  333. d.z = fma(s, b.z, c.z);
  334. d.w = fma(s, b.w, c.w);
  335. return d;
  336. }
  337. inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  338. float fa = half_to_float(a);
  339. float fb = half_to_float(b);
  340. return fa * fb + fc;
  341. }
  342. inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  343. float2 fa = half2_to_float2(a);
  344. float2 fb = half2_to_float2(b);
  345. return fma(fa, fb, fc);
  346. }
  347. inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  348. return fma(h0_h0(a), b, fc);
  349. }
  350. inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  351. Float4_ fd;
  352. fd.x = fma(a.x, b.x, fc.x);
  353. fd.y = fma(a.y, b.y, fc.y);
  354. return fd;
  355. }
  356. inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  357. uint32_t s = h0_h0(a);
  358. Float4_ fd;
  359. fd.x = fma(s, b.x, fc.x);
  360. fd.y = fma(s, b.y, fc.y);
  361. return fd;
  362. }
  363. inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  364. Float8_ fd;
  365. fd.x = fma(a.x, b.x, fc.x);
  366. fd.y = fma(a.y, b.y, fc.y);
  367. fd.z = fma(a.z, b.z, fc.z);
  368. fd.w = fma(a.w, b.w, fc.w);
  369. return fd;
  370. }
  371. inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  372. uint32_t s = h0_h0(a);
  373. Float8_ fd;
  374. fd.x = fma(s, b.x, fc.x);
  375. fd.y = fma(s, b.y, fc.y);
  376. fd.z = fma(s, b.z, fc.z);
  377. fd.w = fma(s, b.w, fc.w);
  378. return fd;
  379. }
  380. // Vector sum.
  381. template<>
  382. inline __device__ float sum(uint16_t v) {
  383. return half_to_float(v);
  384. }
  385. template<>
  386. inline __device__ float sum(uint32_t v) {
  387. float2 tmp = half2_to_float2(v);
  388. return tmp.x + tmp.y;
  389. }
  390. template<>
  391. inline __device__ float sum(uint2 v) {
  392. uint32_t c = add(v.x, v.y);
  393. return sum(c);
  394. }
  395. template<>
  396. inline __device__ float sum(uint4 v) {
  397. uint32_t c = add(v.x, v.y);
  398. c = add(c, v.z);
  399. c = add(c, v.w);
  400. return sum(c);
  401. }
  402. // From float32 to float16.
  403. inline __device__ void from_float(uint16_t& dst, float src) {
  404. dst = float_to_half(src);
  405. }
  406. inline __device__ void from_float(uint32_t& dst, float2 src) {
  407. dst = float2_to_half2(src);
  408. }
  409. inline __device__ void from_float(uint2& dst, Float4_ src) {
  410. dst.x = float2_to_half2(src.x);
  411. dst.y = float2_to_half2(src.y);
  412. }
  413. inline __device__ void from_float(uint4& dst, Float8_ src) {
  414. dst.x = float2_to_half2(src.x);
  415. dst.y = float2_to_half2(src.y);
  416. dst.z = float2_to_half2(src.z);
  417. dst.w = float2_to_half2(src.w);
  418. }
  419. // From float16 to float32.
  420. inline __device__ float to_float(uint16_t u) {
  421. return half_to_float(u);
  422. }
  423. inline __device__ float2 to_float(uint32_t u) {
  424. return half2_to_float2(u);
  425. }
  426. inline __device__ Float4_ to_float(uint2 u) {
  427. Float4_ tmp;
  428. tmp.x = half2_to_float2(u.x);
  429. tmp.y = half2_to_float2(u.y);
  430. return tmp;
  431. }
  432. inline __device__ Float8_ to_float(uint4 u) {
  433. Float8_ tmp;
  434. tmp.x = half2_to_float2(u.x);
  435. tmp.y = half2_to_float2(u.y);
  436. tmp.z = half2_to_float2(u.z);
  437. tmp.w = half2_to_float2(u.w);
  438. return tmp;
  439. }
  440. // Zero-out a variable.
  441. inline __device__ void zero(uint16_t& dst) {
  442. dst = uint16_t(0);
  443. }
  444. } // namespace aphrodite