dtype_bfloat16.cuh

/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

  typedef __hip_bfloat162 __nv_bfloat162;
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace aphrodite {
// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
}
// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  #ifndef USE_ROCM
    return a + b;
  #else
    return __hadd(a, b);
  #endif
#endif
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}
// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}
// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

} // namespace aphrodite
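
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header): how the
// Vec/FloatVec traits and the fma()/sum() overloads above are typically
// combined to compute a BF16 dot product with FP32 accumulation. The function
// name bf16_dot and the VEC_SIZE parameter are hypothetical; sum() for the
// FP32 accumulator types is assumed to be provided by dtype_float32.cuh, and
// the q/k pointers are assumed to be aligned to the vector width.
//
// namespace aphrodite {
//
// template<int VEC_SIZE>
// inline __device__ float bf16_dot(const __nv_bfloat16* q,
//                                  const __nv_bfloat16* k) {
//   using Q_vec = typename Vec<__nv_bfloat16, VEC_SIZE>::Type;  // packed BF16 lanes
//   using A_vec = typename FloatVec<Q_vec>::Type;               // matching FP32 accumulator
//   Q_vec qv = *reinterpret_cast<const Q_vec*>(q);
//   Q_vec kv = *reinterpret_cast<const Q_vec*>(k);
//   A_vec acc = {};             // zero-initialize the FP32 accumulator
//   acc = fma(qv, kv, acc);     // BF16 multiply, FP32 accumulate
//   return sum(acc);            // horizontal reduction to a single float
// }
//
// } // namespace aphrodite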