cache.cpp

#include <cstring>  // std::memcpy
#include <map>
#include <vector>

#include "cpu_types.hpp"

namespace {
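// Copies whole KV-cache blocks within every layer's key and value caches.
// mapping_pairs is expected to be a [num_pairs, 2] tensor of
// (source_block, target_block) indices; each pair triggers one memcpy of
// element_num_per_block elements in the key cache and one in the value cache.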
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
                          std::vector<torch::Tensor> const& value_caches,
                          const torch::Tensor& mapping_pairs,
                          const int element_num_per_block,
                          const int layer_num) {
  const size_t pair_num = mapping_pairs.size(0);
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      int64_t source_offset =
          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
      int64_t target_offset =
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();

      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t* source_ptr = key_cache_ptr + source_offset;
      scalar_t* target_ptr = key_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}
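
// Scatters per-token key/value vectors into the paged KV cache.
// slot_mapping[token] gives the flat slot index (block_index * block_size +
// block_offset) for that token; negative entries mark padding tokens that are
// skipped. Judging by the indexing below, the key cache is laid out as
// [num_blocks, num_heads, head_size/x, block_size, x] and the value cache as
// [num_blocks, num_heads, head_size, block_size].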
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
            token_idx * value_stride + head_idx * head_size;
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;

        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        scalar_t* target_key_head_ptr = key_cache +
                                        block_elem_num * block_index +
                                        head_idx * block_size * head_size;
        scalar_t* target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;
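        // Keys: the head dimension is split into chunks of x elements, so a
        // chunk starting at src_key_idx lands at
        // [src_key_idx / x][block_offset][0..x) within this head's region.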
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              src_key_idx * block_size + block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }
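        // Values: stored as [head_size, block_size] within this head's
        // region, so each element goes to row src_value_idx, column
        // block_offset.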
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              src_value_idx * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
}  // namespace

// NOTE: the key_caches and value_caches vectors are constant, but the Tensors
// they contain are not. The vectors need to be const refs in order to satisfy
// PyTorch's C++ operator registration code (an illustrative registration
// sketch follows copy_blocks below).
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  unsigned num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  const int element_num_per_block = key_caches[0][0].numel();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                       element_num_per_block, num_layers);
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
      });
}
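
// Illustrative sketch only: one way a wrapper such as copy_blocks could be
// exposed through PyTorch's C++ operator registration, which is why the cache
// vectors above are taken as const refs. The library name
// "aphrodite_cpu_example" and the schema string are assumptions for the sake
// of this sketch; the real bindings live elsewhere in the repo and also carry
// mutation annotations on the cache arguments.
//
//   TORCH_LIBRARY(aphrodite_cpu_example, m) {
//     m.def(
//         "copy_blocks(Tensor[] key_caches, Tensor[] value_caches, "
//         "Tensor block_mapping) -> ()");
//   }
//   TORCH_LIBRARY_IMPL(aphrodite_cpu_example, CPU, m) {
//     m.impl("copy_blocks", &copy_blocks);
//   }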

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype, double k_scale,
                       double v_scale) {
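  // The CPU path only handles unscaled KV caches: kv_cache_dtype is accepted
  // but not used here, and the check below rejects any k_scale/v_scale other
  // than 1.0.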
  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  APHRODITE_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
        reshape_and_cache_cpu_impl<scalar_t>(
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
            value_stride, num_heads, head_size, block_size, x);
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
      });
}
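
// Not implemented for the CPU backend; kept so the shared cache-op interface
// stays complete. Calling it always raises.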
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}