// common.h
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#if defined(ENABLE_BF16)
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>
  27. namespace tensorrt_llm
  28. {
  29. namespace kernels
  30. {
  31. enum class WeightOnlyQuantType
  32. {
  33. Int4b,
  34. Int8b
  35. };
  36. enum class WeightOnlyType
  37. {
  38. PerChannel,
  39. GroupWise
  40. };
  41. struct WeightOnlyPerChannel;
  42. template <int GS>
  43. struct WeightOnlyGroupWise;
  44. enum class WeightOnlyActivationFunctionType
  45. {
  46. Gelu,
  47. Relu,
  48. Identity,
  49. InvalidType
  50. };
  51. enum class WeightOnlyActivationType
  52. {
  53. FP16,
  54. BF16
  55. };
  56. struct WeightOnlyParams
  57. {
  58. // ActType is fp16 or bf16
  59. using ActType = void;
  60. using WeiType = uint8_t;
  61. const uint8_t* qweight;
  62. const ActType* scales;
  63. const ActType* zeros;
  64. const ActType* in;
  65. const ActType* act_scale;
  66. const ActType* bias;
  67. ActType* out;
  68. const int m;
  69. const int n;
  70. const int k;
  71. const int group_size;
  72. WeightOnlyQuantType quant_type;
  73. WeightOnlyType weight_only_type;
  74. WeightOnlyActivationFunctionType act_func_type;
  75. WeightOnlyActivationType act_type;
  76. WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
  77. const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
  78. const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
  79. const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
  80. : qweight(_qweight)
  81. , scales(_scales)
  82. , zeros(_zeros)
  83. , in(_in)
  84. , act_scale(_act_scale)
  85. , bias(_bias)
  86. , out(_out)
  87. , m(_m)
  88. , n(_n)
  89. , k(_k)
  90. , group_size(_group_size)
  91. , quant_type(_quant_type)
  92. , weight_only_type(_weight_only_type)
  93. , act_func_type(_act_func_type)
  94. , act_type(_act_type)
  95. {
  96. }
  97. };
  98. } // namespace kernels
  99. } // namespace tensorrt_llm