cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/int8_kvcache/quant_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

enum kv_cache_dtype {
  AUTO,
#ifdef ENABLE_FP8_E5M2
  FP8_E5M2,
#endif
  INT8
};

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
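
// Usage sketch (illustrative only; `gpu_key_cache`, `cpu_key_cache`, and the block
// numbers below are assumptions, not names from this file): swap GPU blocks 0 and 3
// into CPU blocks 5 and 7 of a same-shaped, pinned host tensor.
//
//   std::map<int64_t, int64_t> mapping = {{0, 5}, {3, 7}};
//   swap_blocks(gpu_key_cache, cpu_key_cache, mapping);
//
// The copies are issued with cudaMemcpyAsync on the current stream, so callers that
// read the destination right away must synchronize the stream first.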
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
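
// Usage sketch (illustrative only; the cache vectors and block numbers are assumptions):
// duplicate block 0 into blocks 4 and 9 and block 2 into block 6, in every layer's key
// and value cache. The one-to-many map is flattened into (src, dst) pairs
// (0,4), (0,9), (2,6) before being handed to copy_blocks_kernel.
//
//   std::map<int64_t, std::vector<int64_t>> mapping = {{0, {4, 9}}, {2, {6}}};
//   copy_blocks(key_caches, value_caches, mapping);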
namespace aphrodite {

template<typename scalar_t, typename cache_t, kv_cache_dtype KV_CACHE_DTYPE>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float k_scale,
  const float k_zp,
  const float v_scale,
  const float v_zp) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (KV_CACHE_DTYPE == INT8) {
      key_cache[tgt_key_idx] = int8::quant(tgt_key, k_scale, k_zp);
      value_cache[tgt_value_idx] = int8::quant(tgt_value, v_scale, v_zp);
#ifdef ENABLE_FP8_E5M2
    } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace aphrodite
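
// Worked layout example for the kernel above (the sizes are assumptions chosen for
// illustration, not values from this file): with num_heads = 32, head_size = 128,
// block_size = 16, x = 8, a token mapped to slot 35 lands in block_idx = 35 / 16 = 2
// at block_offset = 35 % 16 = 3. For the element with head_idx = 1, head_offset = 20
// (so x_idx = 2, x_offset = 4):
//   tgt_key_idx   = 2*32*16*16*8 + 1*16*16*8 + 2*16*8 + 3*8 + 4 = 133404
//   tgt_value_idx = 2*32*128*16  + 1*128*16  + 20*16  + 3       = 133443
// i.e. the key cache stores groups of x contiguous elements per (head, block offset),
// while the value cache is laid out as [head, head_offset, block_offset].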
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_CACHE_DTYPE)                                       \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_CACHE_DTYPE><<<grid, block, 0, stream>>>(   \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                        \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                      \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                               \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                             \
    slot_mapping.data_ptr<int64_t>(),                                                               \
    key_stride,                                                                                     \
    value_stride,                                                                                   \
    num_heads,                                                                                      \
    head_size,                                                                                      \
    block_size,                                                                                     \
    x,                                                                                              \
    k_scale,                                                                                        \
    k_zp,                                                                                           \
    v_scale,                                                                                        \
    v_zp);
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float k_scale = 1.0f,
  const float k_zp = 0.0f,
  const float v_scale = 1.0f,
  const float v_zp = 0.0f)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, AUTO);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, AUTO);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, AUTO);
    }
#ifdef ENABLE_FP8_E5M2
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, FP8_E5M2);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, FP8_E5M2);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, FP8_E5M2);
    }
#endif
  } else if (kv_cache_dtype == "int8") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, int8_t, INT8);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, int8_t, INT8);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, int8_t, INT8);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
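
// Usage sketch (illustrative only; the tensor names, shapes, and quantization
// parameters are assumptions): scatter an fp16 batch of keys/values into an int8
// paged cache, quantizing with a per-tensor scale of 0.05 and zero point of 0.
//
//   // key, value:  [num_tokens, num_heads, head_size] (fp16)
//   // key_cache:   [num_blocks, num_heads, head_size/x, block_size, x] (int8)
//   // value_cache: [num_blocks, num_heads, head_size, block_size] (int8)
//   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
//                     "int8", /*k_scale=*/0.05f, /*k_zp=*/0.0f,
//                     /*v_scale=*/0.05f, /*v_zp=*/0.0f);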
namespace aphrodite {

// Grid: (num_tokens). Each thread block gathers one token's keys and values.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // NOTE: despite the name, this is the number of elements per token.
  const int num_tokens = num_heads * head_size;
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
  }
}
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite
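
// Access-pattern note for gather_cached_kv_kernel_optimized (the sizes here are
// assumptions for illustration): with num_heads = 32 and head_size = 128, dim = 4096
// and unrolled_dim = 1024, so on its first loop iteration thread t handles elements
// t, t + 1024, t + 2048, and t + 3072 of the token. All four cache reads are staged
// into registers before any write to `key`/`value`, separating the strided gather
// from the contiguous store.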
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
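
// Usage sketch (illustrative only; tensor names are assumptions): gather the cached
// keys/values for a batch of tokens back into dense [num_tokens, num_heads, head_size]
// tensors, e.g. for an attention path that cannot read the paged layout directly.
//
//   torch::Tensor key = torch::empty({num_tokens, num_heads, head_size}, key_cache.options());
//   torch::Tensor value = torch::empty_like(key);
//   gather_cached_kv(key, value, key_cache, value_cache, slot_mapping);
//
// Note that this path reads slot_mapping as int32, unlike reshape_and_cache, which
// reads it as int64.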
namespace aphrodite {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace aphrodite
#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                        \
  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(    \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                               \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                              \
    block_stride);
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
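
// Usage sketch (illustrative only; tensor names are assumptions): the conversion
// direction is inferred from the source/destination dtypes, so the same entry point
// both compresses a half-precision cache into fp8_e5m2 bytes and expands it back.
//
//   torch::Tensor fp8_cache =
//       torch::empty_like(key_cache, key_cache.options().dtype(torch::kUInt8));
//   convert_fp8_e5m2(key_cache, fp8_cache);  // fp16 -> fp8_e5m2 (uint8 storage)
//   convert_fp8_e5m2(fp8_cache, key_cache);  // fp8_e5m2 -> fp16
//
// This path requires a build with ENABLE_FP8_E5M2; otherwise the kernel hits assert(false).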