// cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
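
// Copies entire cache blocks between `src` and `dst`, which may live on
// different devices (GPU<->GPU on the same device index, GPU->CPU, or
// CPU->GPU). `block_mapping` maps source block numbers to destination block
// numbers; each entry issues one cudaMemcpyAsync on the current stream.
//
// Illustrative call (tensor names are hypothetical): copy blocks 0 and 1 of a
// GPU cache into blocks 3 and 7 of a CPU cache with the same block shape:
//
//   std::map<int64_t, int64_t> mapping = {{0, 3}, {1, 7}};
//   swap_blocks(gpu_key_cache, cpu_key_cache, mapping);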
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
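
// Unlike swap_blocks, copy_blocks stays on a single GPU and performs all
// requested block copies in one kernel launch: each CUDA block handles one
// (layer, src->dst pair) combination and its threads stride over the elements
// of that cache block, duplicating both the key and the value entries.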
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite
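
// Host-side launcher for copy_blocks_kernel. The {src: [dst, ...]} mapping is
// flattened into a contiguous [src0, dst0, src1, dst1, ...] array before being
// copied to the GPU. For example (values are illustrative), the mapping
// {2: {5, 7}, 3: {9}} flattens to [2, 5, 2, 7, 3, 9], giving num_pairs = 3 and
// a kernel grid of (num_layers, 3).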
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
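
// The key cache has shape [num_blocks, num_heads, head_size/x, block_size, x]:
// within a head, the layout iterates over chunks of x elements (head_size/x of
// them), then over the block_size tokens of the block, then over the x elements
// of the chunk. The value cache has shape [num_blocks, num_heads, head_size,
// block_size]. Worked example with illustrative sizes (num_heads = 8,
// head_size = 128, block_size = 16, x = 8): the key element at head_idx = 1,
// head_offset = 21 for the token at block_offset = 3 has x_idx = 21 / 8 = 2 and
// x_offset = 21 % 8 = 5, so its offset within the block is
// 1 * (128/8) * 16 * 8 + 2 * 16 * 8 + 3 * 8 + 5 = 2048 + 256 + 24 + 5 = 2333,
// on top of the block's base offset block_idx * 8 * (128/8) * 16 * 8.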
namespace aphrodite {

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    key_cache[tgt_key_idx] = key[src_key_idx];
    value_cache[tgt_value_idx] = value[src_value_idx];
  }
}

} // namespace aphrodite
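
// Scatters the freshly computed keys/values for a batch of tokens into the
// paged caches. slot_mapping[i] is the flat slot of token i
// (block_idx * block_size + block_offset), stored as int64; a negative slot
// marks a padding token and is skipped. One CUDA block is launched per token,
// with up to 512 threads striding over the num_heads * head_size elements.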
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      aphrodite::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
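
// The gather_cached_kv kernels perform the inverse of reshape_and_cache: they
// read each token's keys/values out of the paged caches (using the same layout
// arithmetic) and write them into dense [num_tokens, num_heads, head_size]
// tensors. Note that slot_mapping is int32 here, whereas reshape_and_cache
// expects int64.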
namespace aphrodite {

// Grid: (num_tokens)
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,     // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;  // elements per token
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
  }
}
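
// Variant of gather_cached_kv_kernel with the copy loop manually unrolled by a
// factor of 4: each iteration computes four source/target index sets, loads the
// four cache elements through APHRODITE_LDG, and then issues the four stores.
// For example (sizes are illustrative), with num_heads * head_size = 4096 each
// thread strides over unrolled_dim = 1024 positions and handles 4 elements per
// iteration, spaced unrolled_dim apart.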
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
    scalar_t *__restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
    scalar_t *__restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
    const scalar_t *__restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
    const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
    const int *__restrict__ slot_mapping,     // [num_tokens]
    const int key_stride,
    const int value_stride,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
  {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite
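
// Host-side launcher; it dispatches gather_cached_kv_kernel_optimized, so
// num_heads * head_size must be divisible by 4 (asserted inside the kernel).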
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
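
// These host functions are typically exposed to Python through the project's
// torch extension bindings. A minimal sketch of such a registration, assuming
// a standard pybind11 torch extension (module layout and docstrings are
// illustrative, not the project's actual binding file):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("swap_blocks", &swap_blocks, "Swap cache blocks between devices");
//     m.def("copy_blocks", &copy_blocks, "Copy cache blocks within the GPU cache");
//     m.def("reshape_and_cache", &reshape_and_cache, "Scatter new keys/values into the paged cache");
//     m.def("gather_cached_kv", &gather_cached_kv, "Gather cached keys/values into dense tensors");
//   }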