1
0

format.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. /*
  2. * Adapted from https://github.com/InternLM/lmdeploy
  3. * Copyright (c) OpenMMLab. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #pragma once
#include <cstddef>
#include <cstdint>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
  21. namespace aphrodite {
  22. namespace autoquant {
  23. void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k,
  24. cudaStream_t st = {});
  25. void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k,
  26. cudaStream_t st = {});
  27. void convert_s4_k_m8(uint32_t* A_dst, half2* Q_dst, half* workspace,
  28. const uint32_t* A_src, const half* scales,
  29. const uint32_t* qzeros, int m, int k, int group_size,
  30. cudaStream_t st = {});
  31. void convert_s4_k_m8(uint32_t* A_dst, __nv_bfloat162* Q_dst,
  32. __nv_bfloat16* workspace, const uint32_t* A_src,
  33. const __nv_bfloat16* scales, const uint32_t* qzeros, int m,
  34. int k, int group_size, cudaStream_t st = {});
  35. void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k,
  36. int size_per_head, cudaStream_t st = {});
  37. void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k,
  38. cudaStream_t st = {});
  39. void dequantize_s4(uint4* dst, const uint32_t* src, size_t count,
  40. cudaStream_t st = {});
  41. } // namespace autoquant
  42. } // namespace aphrodite