cuda_bf16_fallbacks.cuh

// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"

#include <cuda_fp16.h>

namespace fastertransformer {

#ifdef ENABLE_BF16
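
// Each helper below provides a software fallback for architectures older than
// sm_80 (Ampere): when __CUDA_ARCH__ < 800, operands are widened to float,
// computed in float, and rounded back to bfloat16; on sm_80 and newer, the
// native __nv_bfloat16 intrinsics are used instead.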
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
    f_val.y = __high2float(val);
    return f_val;
#else
    return __bfloat1622float2(val);
#endif
}
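
// Clamps both bfloat16 lanes to the int8 range [-128, 127], converts each lane
// to int8, and packs the two resulting int8 values into a single int16.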
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union { int8_t int8[2]; int16_t int16; };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union { int8_t int8[2]; int16_t int16; };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}
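
// Converts a float2 to a packed __nv_bfloat162, rounding to nearest even.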
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
    return __float22bfloat162_rn(val);
#endif
}
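
// Broadcasts a single __nv_bfloat16 into both halves of an __nv_bfloat162.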
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
    val2.y = val;
    return val2;
#else
    return __bfloat162bfloat162(val);
#endif
}
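
// Scalar (bf16h*) and packed (bf16h*2) add, subtract, multiply, and fused
// multiply-add on bfloat16 operands.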
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
    return __hadd2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
    return __hadd(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
    return __hsub2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
    return __hsub(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
    return __hmul2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
    return __hmul(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh, fzl, fzh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    fzl = __low2float(z);
    fzh = __high2float(z);
    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
    return __hfma2(x, y, z);
#endif
}

inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}
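
// Elementwise exponential of both bfloat16 lanes.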
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
#endif
}
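
// For pre-sm_80 compilation, where the native __nv_bfloat162 operators are not
// available, provide operator* and operator+ in terms of the fallbacks above,
// plus a make_bfloat162() constructor.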
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 t;
    t.x = x;
    t.y = y;
    return t;
}
#endif
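
// Three- and four-operand convenience overloads: chained adds, chained
// multiplies, and a fused a * b * c + d.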
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    fdl = __low2float(d);
    fdh = __high2float(d);
    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
    return a * b * c + d;
#endif
}
#endif  // ENABLE_BF16

}  // namespace fastertransformer
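
// Minimal usage sketch (illustrative only, not part of the original header):
// a hypothetical elementwise kernel that adds a packed bf16 bias. Each helper
// dispatches on __CUDA_ARCH__ internally, so the same kernel compiles for both
// pre-Ampere and Ampere+ targets.
//
//   __global__ void add_bias(__nv_bfloat162* x, const __nv_bfloat162* bias, int n) {
//       const int i = blockIdx.x * blockDim.x + threadIdx.x;
//       if (i < n) {
//           x[i] = fastertransformer::bf16hadd2(x[i], bias[i]);
//       }
//   }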