/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "common.h"
#include "utility.h"

namespace tensorrt_llm
{
namespace kernels
{

template <typename ActType>
struct ActTypeDetails;

template <>
struct ActTypeDetails<half>
{
    using CutlassType = cutlass::half_t;
    using Vec2 = half2;

    __device__ __forceinline__ static Vec2 to_vec2(half v)
    {
        return __half2half2(v);
    }
};

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
template <>
struct ActTypeDetails<__nv_bfloat16>
{
    using CutlassType = cutlass::bfloat16_t;
    using Vec2 = __nv_bfloat162;

    __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v)
    {
        return __bfloat162bfloat162(v);
    }
};
#endif
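
// Usage sketch (illustrative only; variable names here are placeholders, the pattern mirrors the GEMV kernel below):
// Vec2 and to_vec2() broadcast a scalar so two adjacent weights can be dequantized with one packed FMA, e.g. for
// half activations:
//   ActTypeDetails<half>::Vec2 s2 = ActTypeDetails<half>::to_vec2(scale);
//   ActTypeDetails<half>::Vec2 z2 = ActTypeDetails<half>::to_vec2(zero);
//   w2 = __hfma2(w2, s2, z2); // w2 holds two converted weight values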

template <typename ActType, WeightOnlyQuantType QType>
struct ConverterSelector
{
    static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b);

    using WeiType = std::conditional_t<QType == WeightOnlyQuantType::Int4b, cutlass::uint4b_t, uint8_t>;
    static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4;
    using Converter
        = cutlass::FastInterleavedAndBiasedNumericArrayConverter<typename ActTypeDetails<ActType>::CutlassType, WeiType,
            kConvertCount>;
};
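
// Note: kConvertCount is the number of quantized values handled by a single Converter::convert() call; 8 x int4 or
// 4 x int8 fill one 32-bit register, which the converter expands to 8 (or 4) fp16/bf16 values at once.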

template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyDetails;

template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int4b>
{
    // Every four rows of the original weights are interleaved into a single row with a stride of 64, so if each
    // thread processes 32 elements (for int4, we can use ldg.128 to load the weights), then every group of two
    // adjacent threads alternately processes weights from four different original rows.
    // For example, for every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N in the interleaved layout,
    // the first 64 come from [64*i, 64*(i+1)-1] of row 4N before interleaving,
    // the second 64 come from [64*i, 64*(i+1)-1] of row 4N+1 before interleaving, and so on.
    // So if each thread loads 32 int4 elements, the elements of each 2 adjacent threads within each group of 8
    // consecutive threads come from rows 4N ~ 4N+3, respectively, before interleaving.
    static constexpr int kElemBits = 4;
    static constexpr int kInterleave = 4;
    static constexpr int kStride = 64;

    // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm:
    // input  0 1 2 3 4  5  6  7  8 9 10 11 12 13 14 15 ... 31
    // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
    static constexpr int kShuffleSize = 32;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided = 4;

    // Each warp completes the internal reduction and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding addresses in shared memory.
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
#pragma unroll
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
        {
#pragma unroll
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 2] = res[i];
            }
        }
        __syncthreads();
    }
};
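
// Worked example of the int4 thread-to-row mapping (illustrative, following the comment above): with kStride = 64
// and 32 elements per thread, threads {0, 1} of every group of 8 consecutive threads read the 64-element segment
// that originated from row 4N, threads {2, 3} the segment from row 4N+1, threads {4, 5} from row 4N+2, and
// threads {6, 7} from row 4N+3.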

template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int8b>
{
    // Every two rows of the original weights are interleaved into a single row with a stride of 64, so if each
    // thread processes 16 elements (for int8, we can use ldg.128 to load the weights), then every group of four
    // adjacent threads alternately processes weights from two different original rows.
    // For example, for every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N in the interleaved layout,
    // the first 64 come from [64*i, 64*(i+1)-1] of row 2N before interleaving,
    // and the last 64 come from [64*i, 64*(i+1)-1] of row 2N+1 before interleaving.
    // So if each thread loads 16 int8 elements, the elements of the first four and last four threads of each group
    // of 8 consecutive threads come from row 2N and row 2N+1, respectively, before interleaving.
    static constexpr int kElemBits = 8;
    static constexpr int kInterleave = 2;
    static constexpr int kStride = 64;

    // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm:
    // input  0 1 2 3 4 5  6  7 8 9 10 11 12 13 14 15
    // weight 0 1 8 9 2 3 10 11 4 5 12 13 6  7  14 15
    static constexpr int kShuffleSize = 16;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 2;
    static constexpr int kShuffleStrided = 4;

    // Each warp completes the internal reduction and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding addresses in shared memory.
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
#pragma unroll
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 2);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 4)
        {
#pragma unroll
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 4] = res[i];
            }
        }
        __syncthreads();
    }
};
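
// Worked example of the int8 thread-to-row mapping (illustrative, following the comment above): with kStride = 64
// and 16 elements per thread, threads {0..3} of every group of 8 consecutive threads read the 64-element segment
// that originated from row 2N, and threads {4..7} read the segment from row 2N+1.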

template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyKernelDetails
{
    using Layout = WeightOnlyDetails<ActType, QType>;

    static constexpr int kElemBits = Layout::kElemBits;
    static constexpr int kInterleave = Layout::kInterleave;
    static constexpr int kStride = Layout::kStride;
    static constexpr int kShuffleSize = Layout::kShuffleSize;
    static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile;
    static constexpr int kShuffleContinous = Layout::kShuffleContinous;
    static constexpr int kShuffleStrided = Layout::kShuffleStrided;

    // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace.
    // Input int8 data layout
    //      [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
    //
    // Converted fp16/bf16 data layout
    //      [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
    //
    // Input int4 data layout
    //      [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
    //
    // Converted fp16/bf16 data layout
    //      [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
    static constexpr int kConvertCount = ConverterSelector<ActType, QType>::kConvertCount;
    using Converter = typename ConverterSelector<ActType, QType>::Converter;

    // Use ldg.128 to load data from global memory.
    static constexpr int kAccessSize = 128;
    using AccessType = uint4;

    static constexpr int kElemsPerByte = 8 / kElemBits;
    static constexpr int kElemsPerThread = kAccessSize / kElemBits;
    static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte;
    static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread;
    static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave;

    static constexpr int kConvertIters = kElemsPerThread / kConvertCount;

    // Each thread loads 16 (int8) or 32 (int4) quantized weight elements per ldg.128, so several ldg.128 accesses
    // are needed to load the matching number of fp16/bf16 activation elements.
    static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8);
    static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess;
};
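
// Worked values of the derived constants above (a sanity check, assuming half activations):
//   Int4b: kElemsPerByte = 2, kElemsPerThread = 32, kBytePerThread = 16, kThreadsNumPerTile = 2,
//          kThreadsNumPerInterleave = 8, kConvertIters = 4, kActivationElemNumPerAccess = 8, kActivationAccessNum = 4
//   Int8b: kElemsPerByte = 1, kElemsPerThread = 16, kBytePerThread = 16, kThreadsNumPerTile = 4,
//          kThreadsNumPerInterleave = 8, kConvertIters = 4, kActivationElemNumPerAccess = 8, kActivationAccessNum = 2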

template <typename WeightOnlyFlag>
struct WeightOnlyProperties;

template <>
struct WeightOnlyProperties<WeightOnlyPerChannel>
{
    static constexpr bool kIsFineGrained = false;
    static constexpr int kGroupSize = 0;
};

template <int GS>
struct WeightOnlyProperties<WeightOnlyGroupWise<GS>>
{
    static constexpr bool kIsFineGrained = true;
    static constexpr int kGroupSize = GS;
};
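
// For example (group size 128 chosen only for illustration): WeightOnlyProperties<WeightOnlyGroupWise<128>> yields
// kIsFineGrained = true and kGroupSize = 128, i.e. one scale (and optional zero) per 128 elements along k, whereas
// WeightOnlyProperties<WeightOnlyPerChannel> keeps a single scale per output channel.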

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
struct WeightOnlyScaleLoader
{
    using ElemType = ActType;
    using Details = WeightOnlyKernelDetails<ActType, QType>;
    static constexpr bool kIsFineGrained = WeightOnlyProperties<WeightOnlyFlag>::kIsFineGrained;
    static constexpr int kGroupSize = WeightOnlyProperties<WeightOnlyFlag>::kGroupSize;

private:
    const ElemType* _scales;
    const ElemType* _zeros;
    int _stride;
    int _offset;

public:
    __device__ __forceinline__ WeightOnlyScaleLoader(
        const ElemType* scales, const ElemType* zeros, int initial_offset, int stride)
        : _scales(scales)
        , _zeros(zeros)
        , _stride(stride)
    {
        _scales += initial_offset;
        if constexpr (Zero)
        {
            _zeros += initial_offset;
        }
        // Compute the k-dimension index (in the layout before interleaving) of the first element processed by the
        // current thread; it is used to locate scales and zeros for group-wise weight-only quantization.
        _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride
            + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread;
    }

    __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid)
    {
        int offset = nid * Details::kInterleave;
        if constexpr (kIsFineGrained)
        {
            offset += _offset / kGroupSize * _stride;
        }
        scale = _scales[offset];
        if constexpr (Zero)
        {
            zero = _zeros[offset];
        }
        else
        {
            zero = static_cast<ElemType>(0.f);
        }
    }

    __device__ __forceinline__ void advance()
    {
        _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave;
    }

    __device__ __forceinline__ int offset()
    {
        return _offset;
    }
};
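
// Reading of the indexing above (an interpretation, not additional functionality): _stride is n, so for group-wise
// quantization load() effectively picks scales[(k_idx / kGroupSize) * n + n_idx], i.e. scales/zeros are expected in
// row-major [k / group_size, n] order; for per-channel quantization the k term is dropped and only n_idx is used.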

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
    using ActType2 = typename ActTypeDetails<ActType>::Vec2;
    using Details = WeightOnlyKernelDetails<ActType, QType>;

    using Converter = typename Details::Converter;
    using AccType = typename Details::AccessType;
    using CvtSrcType = typename Converter::source_type;
    using CvtResType = typename Converter::result_type;
    using ScaleLoader = WeightOnlyScaleLoader<ActType, QType, WeightOnlyFlag, Zero, BlockSize>;
    extern __shared__ uint8_t shmem[];
    constexpr int Interleave = Details::kInterleave;
    constexpr int WarpSize = 32;
    constexpr int Num = Batch * NPerBlock;
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int n_start_id = bid * NPerBlock * Interleave;
    // Calculate the n-dimension index of the data processed by the current thread within the interleave tile.
    const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave;

    qweight += n_start_id * k / Details::kElemsPerByte;
    ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n);

    float(*sm)[Num * Interleave] = reinterpret_cast<float(*)[Num * Interleave]>(shmem);

    // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for
    // accumulation between threads.
    ActType accumulator[Num];
    for (int i = 0; i < Num; ++i)
    {
        accumulator[i] = static_cast<ActType>(0.f);
    }
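
    // Loop structure note (descriptive only): each interleaved row packs Interleave original rows, so the k loop
    // below walks k * Interleave packed elements per block column; thread tid starts at tid * kElemsPerThread and
    // advances by BlockSize * kElemsPerThread per iteration, one ldg.128 of quantized weights per column per step.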
    // Iteration in k dimensions
    for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave;
         local_k += BlockSize * Details::kElemsPerThread)
    {
        ActType weights_f16[Details::kElemsPerThread * NPerBlock];
        ActType scale[NPerBlock], zero[NPerBlock];
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx)
        {
            // Load quantized weight and scales/zeros
            uint8_t weights_quantized[Details::kBytePerThread];
            load<AccType>(weights_quantized,
                qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte);
            scale_loader.load(scale[idx], zero[idx], idx);
            ActType weights_vec[Details::kElemsPerThread];
#pragma unroll
            for (int i = 0; i < Details::kConvertIters; ++i)
            {
                // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F type conversion
                assign<CvtResType>(weights_vec + i * Details::kConvertCount,
                    Converter::convert(*reinterpret_cast<CvtSrcType*>(
                        weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte)));
            }
#pragma unroll
            for (int i = 0; i < Details::kShuffleContinous; ++i)
            {
#pragma unroll
                for (int j = 0; j < Details::kShuffleStrided; ++j)
                {
                    // Dequantize the weights and arrange the shuffled elements back to the correct order in the
                    // register array
                    ActType2 v = *reinterpret_cast<ActType2*>(weights_vec + i * Details::kShuffleBasicTile
                        + j * Details::kShuffleContinous * Details::kShuffleBasicTile);
                    v = __hfma2(
                        v, ActTypeDetails<ActType>::to_vec2(scale[idx]), ActTypeDetails<ActType>::to_vec2(zero[idx]));
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 0)
                            * NPerBlock
                        + idx]
                        = v.x;
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 1)
                            * NPerBlock
                        + idx]
                        = v.y;
                }
            }
        }
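
        // Layout note (descriptive only): after the shuffle above, element e of block column idx lives at
        // weights_f16[e * NPerBlock + idx], so two adjacent columns of the same element index are contiguous and
        // can be consumed as one ActType2 in the NPerBlock > 1 accumulation path below.
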
        ActType act_scale_v[Details::kElemsPerThread];
        if constexpr (ActScale)
        {
#pragma unroll
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
                    act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
            }
        }
#pragma unroll
        for (int b = 0; b < Batch; ++b)
        {
            ActType in_v[Details::kElemsPerThread];
#pragma unroll
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                // load activation elements
                load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
                    in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
                if constexpr (ActScale)
                {
#pragma unroll
                    for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
                    {
                        *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
                            *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
                            *reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
                    }
                }
            }
            // Perform vector inner product and accumulate
            if constexpr (NPerBlock == 1)
            {
                ActType2 v = ActTypeDetails<ActType>::to_vec2(static_cast<ActType>(0.f));
#pragma unroll
                for (int y = 0; y < Details::kElemsPerThread; y += 2)
                {
                    v = __hfma2(
                        *reinterpret_cast<ActType2*>(weights_f16 + y), *reinterpret_cast<ActType2*>(in_v + y), v);
                }
                accumulator[b] += __hadd(v.x, v.y);
            }
            else
            {
#pragma unroll
                for (int x = 0; x < NPerBlock / 2; ++x)
                {
#pragma unroll
                    for (int y = 0; y < Details::kElemsPerThread; ++y)
                    {
                        *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2)
                            = __hfma2(*reinterpret_cast<ActType2*>(weights_f16 + y * NPerBlock + x * 2),
                                ActTypeDetails<ActType>::to_vec2(in_v[y]),
                                *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2));
                    }
                }
            }
        }
        scale_loader.advance();
    }

    float reses[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i)
    {
        reses[i] = static_cast<float>(accumulator[i]);
    }

    // Each warp completes the internal reduction and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding addresses in shared memory.
    Details::Layout::sync<Num, WarpSize>(reses, sm);

    // Each thread is responsible for accumulating one output element across warps and storing it to global memory.
    for (int i = tid; i < Num * Interleave; i += BlockSize)
    {
        int nid = i % (NPerBlock * Interleave);
        float v = 0.f;
        for (int j = 0; j < BlockSize / WarpSize; ++j)
        {
            v += sm[j][i];
        }
        float bias_v = 0.f;
        if constexpr (Bias)
        {
            bias_v = static_cast<float>(bias[n_start_id + nid]);
        }
        int b = i / NPerBlock / Interleave;
        out[b * n + n_start_id + nid] = static_cast<ActType>(ActOp<float>::apply(v + bias_v));
    }
}
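
// Launch-geometry note (descriptive only): each block produces NPerBlock * Interleave output columns for all Batch
// rows, and the cross-warp reduction needs (BlockSize / 32) * Batch * NPerBlock * Interleave floats of dynamic
// shared memory, which is exactly the size the launcher below passes to <<<...>>>.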

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    if constexpr (std::is_same_v<ActType, half>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
    else if (std::is_same_v<ActType, nv_bfloat16>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
#endif
}

template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
    int NPerBlock, int Batch, int BlockSize>
struct WeightOnlyBatchedGemvKernelLauncher
{
    static void run(const WeightOnlyParams& params, cudaStream_t stream)
    {
        if (params.act_type == WeightOnlyActivationType::FP16)
        {
            constexpr int kInterleave = WeightOnlyDetails<half, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock, Batch,
                    BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
                    Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
        }
#if defined(ENABLE_BF16)
        else if (params.act_type == WeightOnlyActivationType::BF16)
        {
            constexpr int kInterleave = WeightOnlyDetails<nv_bfloat16, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
                    params.n, params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
                    params.n, params.k);
            }
        }
#endif
    }
};
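
// Host-side usage sketch (illustrative only; the activation functor name and the tile parameters below are
// assumptions, not mandated by this header). A caller fills WeightOnlyParams with the fields consumed above
// (qweight, scales, zeros, in, act_scale, bias, out, n, k, act_type) and picks the template configuration:
//
//   WeightOnlyParams params = /* device pointers plus n, k and act_type */;
//   WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
//       IdentityActivation, /*Zero=*/false, /*Bias=*/false, /*NPerBlock=*/2, /*Batch=*/1,
//       /*BlockSize=*/256>::run(params, stream);
//
// params.n must be divisible by NPerBlock * kInterleave, since the grid is sized as n / NPerBlock / kInterleave.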

} // namespace kernels
} // namespace tensorrt_llm