// xqa_kernel_launcher.cuh
/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cstdlib>  // std::getenv, std::strtol
#include <memory>   // std::unique_ptr
#include <string>   // std::string

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <torch/all.h>

#include "decoder_xqa_impl_common.h"
  21. bool const isPerfsim = []() {
  22. auto const v = std::getenv("XQA_IS_PERFSIM");
  23. if (!v) {
  24. return false;
  25. }
  26. return bool(std::stoi(v));
  27. }();
  28. template <typename T>
  29. class ManagedMemBuf
  30. {
  31. public:
  32. ManagedMemBuf(size_t nbElems): mSize {nbElems} {
  33. if (nbElems != 0) {
  34. void* p;
  35. checkCuda(cudaMallocManaged(&p, sizeof(T) * nbElems));
  36. mData.reset(reinterpret_cast<T*>(p));
  37. }
  38. }
  39. T* get() const {return mData.get();}
  40. size_t size() const {return mSize;}
  41. void prefetch(int dstDevice, cudaStream_t stream = nullptr) const {
  42. if (!isPerfsim) {
  43. checkCuda(cudaMemPrefetchAsync(get(), sizeof(T) * size(), dstDevice, stream));
  44. }
  45. }
  46. T& operator[](size_t i) const {
  47. return mData[i];
  48. };
  49. private:
  50. struct CudaDeleter
  51. {
  52. void operator()(void *p) const {
  53. cudaFree(p);
  54. }
  55. };
  56. std::unique_ptr<T[], CudaDeleter> mData;
  57. size_t mSize;
  58. };
// XQA paged-attention entry point (declaration only; defined in the
// corresponding .cu translation unit).
//
// out              - output tensor, written in place.
// query            - query tensor for the current decode step.
// key_value_cache  - paged KV cache storage.
// num_heads        - number of query attention heads.
// num_kv_heads     - number of key/value heads (presumably <= num_heads
//                    for GQA/MQA; verify against the definition).
// rotary_embedding_dim - RoPE dimension.
// scale            - attention softmax scale factor.
// block_tables     - per-sequence table of cache block indices.
// seq_lens         - per-sequence lengths.
// block_size       - elements per cache block (page size).
// max_seq_len      - maximum sequence length across the batch.
// kv_cache_dtype   - KV-cache dtype selector (string, passed by value to
//                    match the out-of-view definition).
// k_scale, v_scale - dequantization scales for K and V.
//
// NOTE(review): tensor shapes/dtypes are not visible from this header —
// confirm against the kernel implementation before relying on them.
void xqa_paged_attention(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_value_cache,
int64_t num_heads, int64_t num_kv_heads, int64_t rotary_embedding_dim, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::string kv_cache_dtype, double k_scale, double v_scale);