cache.cpp

#include <cstring>  // std::memcpy
#include <map>
#include <vector>

#include "cpu_types.hpp"

namespace {
// Copies whole cache blocks within each layer's key and value caches,
// according to (src_block, dst_block) index pairs.
template <typename scalar_t>
void copy_blocks_cpu_impl(
    std::vector<torch::Tensor>& key_caches,
    std::vector<torch::Tensor>& value_caches,
    const std::vector<std::pair<int64_t, int64_t>>& mapping_pairs,
    const int element_num_per_block, const int layer_num) {
  const size_t pair_num = mapping_pairs.size();
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
      int64_t target_offset =
          element_num_per_block * mapping_pairs[pair].second;

      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t* source_ptr = key_cache_ptr + source_offset;
      scalar_t* target_ptr = key_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}
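// Illustrative example (not part of the original file): with
// element_num_per_block = 4096 and a mapping pair (2, 7), every layer copies
// elements [2 * 4096, 3 * 4096) of its key cache onto [7 * 4096, 8 * 4096),
// and likewise for its value cache, i.e. block_bytes = 4096 * sizeof(scalar_t)
// bytes per memcpy.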
// Scatters each token's key/value head vectors into the blocked KV cache at
// the slot given by slot_mapping. Negative slot indices mark padding tokens
// and are skipped.
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
            token_idx * value_stride + head_idx * head_size;
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;
        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        scalar_t* target_key_head_ptr = key_cache +
                                        block_elem_num * block_index +
                                        head_idx * block_size * head_size;
        scalar_t* target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;
        // Key cache: head_size is split into groups of x contiguous elements.
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              src_key_idx * block_size + block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }
        // Value cache: elements are strided by block_size within the block.
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              src_value_idx * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
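// Cache layout implied by the indexing above and the sizes read in
// reshape_and_cache below: key_cache is
// [num_blocks, num_heads, head_size / x, block_size, x] and value_cache is
// [num_blocks, num_heads, head_size, block_size]. For example, with
// block_size = 16 and x = 8, a token at slot_idx = 21 lands in block_index = 1
// with block_offset = 5, so key element src_key_idx + i is written at offset
// src_key_idx * 16 + 5 * 8 + i within its head's slice of block 1.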
}  // namespace
void copy_blocks(std::vector<torch::Tensor>& key_caches,
                 std::vector<torch::Tensor>& value_caches,
                 const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == static_cast<int>(value_caches.size()));
  if (num_layers == 0) {
    return;
  }

  // Flatten {src_block: [dst_block, ...]} into (src, dst) pairs.
  std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
  mapping_pairs.reserve(block_mapping.size());
  for (const auto& pair : block_mapping) {
    for (const auto& dst : pair.second) {
      mapping_pairs.emplace_back(pair.first, dst);
    }
  }

  // Elements per cache block (dim 0 of each per-layer cache indexes blocks).
  const int element_num_per_block = key_caches[0][0].numel();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
                                       element_num_per_block, num_layers);
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
      });
}
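// Illustrative example (not part of the original file): a block_mapping of
// {0: [2, 3], 5: [1]} expands to mapping_pairs {(0, 2), (0, 3), (5, 1)}, so
// block 0 is duplicated into blocks 2 and 3, and block 5 into block 1, in
// every layer's key and value cache.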
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype, float kv_scale) {
  // Only a unit kv_scale is supported by this CPU path.
  TORCH_CHECK(kv_scale == 1.0f);

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  APHRODITE_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
        reshape_and_cache_cpu_impl<scalar_t>(
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
            value_stride, num_heads, head_size, block_size, x);
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
      });
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const std::map<int64_t, int64_t>& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}
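// Illustrative usage sketch (not part of the original file). The tensor shapes
// below are assumptions inferred from the kernels above, and
// example_reshape_and_cache is a hypothetical helper, not an Aphrodite API.
void example_reshape_and_cache() {
  const int num_tokens = 2, num_heads = 4, head_size = 64;
  const int num_blocks = 8, block_size = 16, x = 8;
  // key/value: [num_tokens, num_heads, head_size], float32 by default.
  auto key = torch::randn({num_tokens, num_heads, head_size});
  auto value = torch::randn({num_tokens, num_heads, head_size});
  // key_cache: [num_blocks, num_heads, head_size / x, block_size, x].
  auto key_cache =
      torch::zeros({num_blocks, num_heads, head_size / x, block_size, x});
  // value_cache: [num_blocks, num_heads, head_size, block_size].
  auto value_cache =
      torch::zeros({num_blocks, num_heads, head_size, block_size});
  // Token 0 -> slot 0 (block 0, offset 0); token 1 -> slot 17 (block 1, offset 1).
  auto slot_mapping = torch::tensor({0, 17}, torch::kInt64);
  reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                    /*kv_cache_dtype=*/"auto", /*kv_scale=*/1.0f);
}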