cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // Use char* so that byte offsets can be added to the base pointers;
  // pointer arithmetic on void* is a non-standard extension.
  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
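
// Illustrative usage sketch (not part of the original file): swap_blocks is
// typically driven by the cache manager to move whole KV-cache blocks between
// a GPU cache tensor and its CPU counterpart. The tensor names below are
// assumptions made for the example.
//
//   // Swap GPU blocks 0 and 1 out into CPU blocks 3 and 7.
//   std::map<int64_t, int64_t> block_mapping = {{0, 3}, {1, 7}};
//   swap_blocks(gpu_key_cache, cpu_key_cache, block_mapping);
//   swap_blocks(gpu_value_cache, cpu_value_cache, block_mapping);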
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int src_block_number = block_mapping[2 * pair_idx];
  int dst_block_number = block_mapping[2 * pair_idx + 1];

  const int src_block_offset = src_block_number * numel_per_block;
  const int dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int src_offset = src_block_offset + i;
    int dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int src_offset = src_block_offset + i;
    int dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == static_cast<int>(value_caches.size()));
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  // Use std::vector rather than a variable-length array, which is not
  // standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create the block mapping array as flattened (src, dst) pairs.
  std::vector<int> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int src_block_number = pair.first;
    for (int dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int>(),
        numel_per_block);
    }));
}
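
// Illustrative usage sketch (assumption, not part of the original file):
// copy_blocks duplicates whole cache blocks on the GPU across every layer's
// key and value cache, e.g. when a sequence is forked and the forked copies
// must stop sharing a prefix block.
//
//   // Copy block 2 into blocks 5 and 9 in all layers.
//   std::map<int64_t, std::vector<int64_t>> block_mapping = {{2, {5, 9}}};
//   copy_blocks(key_caches, value_caches, block_mapping);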
namespace aphrodite {

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int src_key_idx = token_idx * key_stride + i;
    const int src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;
    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
  }
}

} // namespace aphrodite

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      aphrodite::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
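
// Worked index example (illustrative numbers, not from the original file):
// with head_size = 128, x = 8, block_size = 16, consider the element at
// head_offset 19 of head 2 for a token whose slot_mapping entry is 37.
//   block_idx = 37 / 16 = 2,  block_offset = 37 % 16 = 5,
//   x_idx     = 19 / 8  = 2,  x_offset     = 19 % 8  = 3,
// so the key element is written to key_cache[2][2][2][5][3] in the
// [num_blocks, num_heads, head_size/x, block_size, x] layout, while the
// corresponding value element goes to value_cache[2][2][19][5] in the
// [num_blocks, num_heads, head_size, block_size] layout.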
namespace aphrodite {

// Grid: (num_tokens).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,     // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // Number of cache elements per token.
  const int dim = num_heads * head_size;
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
  }
}
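
// NOTE: The kernel above is the straightforward reference version. The
// variant below produces the same result but unrolls the per-token loop by a
// factor of 4 and separates the cache gathers from the output stores;
// gather_cached_kv() further down launches only the optimized variant.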
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t* __restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,     // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
      values_to_store[j] = __ldg(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite

void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
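
// A minimal binding sketch (assumption: illustrative only -- the project
// presumably registers these ops in its own pybind/extension unit, so this
// module definition is not part of the original file).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("swap_blocks", &swap_blocks,
        "Swap cache blocks between src and dst tensors (e.g. GPU <-> CPU).");
  m.def("copy_blocks", &copy_blocks,
        "Copy cache blocks within the GPU caches of every layer.");
  m.def("reshape_and_cache", &reshape_and_cache,
        "Scatter each token's key/value into the paged KV cache.");
  m.def("gather_cached_kv", &gather_cached_kv,
        "Gather each token's key/value back out of the paged KV cache.");
}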