1
0

format.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cstdint>

#include <cuda_bf16.h>  // __nv_bfloat16 / __nv_bfloat162 (bf16 overload below)
#include <cuda_fp16.h>
#include <cuda_runtime.h>
  21. namespace aphrodite {
  22. namespace autoquant {
  23. void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
  24. void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
  25. void convert_s4_k_m8(uint32_t* A_dst,
  26. half2* Q_dst,
  27. half* workspace,
  28. const uint32_t* A_src,
  29. const half* scales,
  30. const uint32_t* qzeros,
  31. int m,
  32. int k,
  33. int group_size,
  34. cudaStream_t st = {});
  35. void convert_s4_k_m8(uint32_t* A_dst,
  36. __nv_bfloat162* Q_dst,
  37. __nv_bfloat16* workspace,
  38. const uint32_t* A_src,
  39. const __nv_bfloat16* scales,
  40. const uint32_t* qzeros,
  41. int m,
  42. int k,
  43. int group_size,
  44. cudaStream_t st = {});
  45. void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
  46. void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
  47. void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st = {});
  48. } // namespace autoquant
  49. } // namespace vllm