- /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #pragma once
- #include "common.h"
- #include "utility.h"
- namespace tensorrt_llm
- {
- namespace kernels
- {
- template <typename ActType>
- struct ActTypeDetails;
- template <>
- struct ActTypeDetails<half>
- {
- using CutlassType = cutlass::half_t;
- using Vec2 = half2;
- __device__ __forceinline__ static Vec2 to_vec2(half v)
- {
- return __half2half2(v);
- }
- };
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
- template <>
- struct ActTypeDetails<__nv_bfloat16>
- {
- using CutlassType = cutlass::bfloat16_t;
- using Vec2 = __nv_bfloat162;
- __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v)
- {
- return __bfloat162bfloat162(v);
- }
- };
- #endif
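- // Illustration only (added for exposition, not used elsewhere in this header; the helper name is
- // hypothetical): to_vec2 broadcasts a scalar so that a single __hfma2 applies the same scale and
- // zero-point to a pair of dequantized weights, which is exactly how the main kernel loop below uses it.
- __device__ __forceinline__ half2 scale_pair_example(half2 w, half scale, half zero)
- {
- return __hfma2(w, ActTypeDetails<half>::to_vec2(scale), ActTypeDetails<half>::to_vec2(zero));
- }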
- template <typename ActType, WeightOnlyQuantType QType>
- struct ConverterSelector
- {
- static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b);
- using WeiType = std::conditional_t<QType == WeightOnlyQuantType::Int4b, cutlass::uint4b_t, uint8_t>;
- static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4;
- using Converter
- = cutlass::FastInterleavedAndBiasedNumericArrayConverter<typename ActTypeDetails<ActType>::CutlassType, WeiType,
- kConvertCount>;
- };
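- // Compile-time sanity sketch (illustration only): one converter call expands a 32-bit chunk of quantized
- // weights, i.e. 8 int4 values or 4 int8 values, into fp16/bf16.
- static_assert(ConverterSelector<half, WeightOnlyQuantType::Int4b>::kConvertCount == 8, "8 nibbles per 32-bit chunk");
- static_assert(ConverterSelector<half, WeightOnlyQuantType::Int8b>::kConvertCount == 4, "4 bytes per 32-bit chunk");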
- template <typename ActType, WeightOnlyQuantType QType>
- struct WeightOnlyDetails;
- template <typename ActType>
- struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int4b>
- {
- // Every four rows of the original weights are interleaved into a single row with a stride of 64, so if each
- // thread processes 32 elements (for int4, one ldg.128 loads a whole thread's worth of weights), then every
- // group of two adjacent threads alternately processes weights from four different original rows.
- // For example, within every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N of the interleaved
- // layout, the first 64 come from [64*i, 64*(i+1)-1] of row 4N before interleaving, the next 64 come from
- // [64*i, 64*(i+1)-1] of row 4N+1 before interleaving, and so on.
- // Hence, if each thread loads 32 int4 elements, the four pairs of adjacent threads within each group of 8
- // consecutive threads read from rows 4N to 4N+3, respectively, of the pre-interleave layout.
- static constexpr int kElemBits = 4;
- static constexpr int kInterleave = 4;
- static constexpr int kStride = 64;
- // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm
- // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 31
- // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
- static constexpr int kShuffleSize = 32;
- static constexpr int kShuffleBasicTile = 2;
- static constexpr int kShuffleContinous = 4;
- static constexpr int kShuffleStrided = 4;
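- // Sanity sketch (illustration only): the shuffle tiling exactly covers one ldg.128 load per thread,
- // i.e. 4 x 4 x 2 = 32 int4 elements = 128 bits.
- static_assert(kShuffleContinous * kShuffleStrided * kShuffleBasicTile == kShuffleSize, "shuffle tiling covers one load");
- static_assert(kShuffleSize * kElemBits == 128, "one ldg.128 worth of int4 weights per thread");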
- // Each warp performs its intra-warp reduction and writes its [Batch * NPerBlock * Interleave] partial results
- // to the corresponding addresses in shared memory
- template <int Num, int WarpSize>
- __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
- {
- #pragma unroll
- for (int i = 0; i < Num; ++i)
- {
- res[i] += __shfl_xor_sync(~0, res[i], 16);
- res[i] += __shfl_xor_sync(~0, res[i], 8);
- res[i] += __shfl_xor_sync(~0, res[i], 1);
- }
- __syncthreads();
- int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
- if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
- {
- #pragma unroll
- for (int i = 0; i < Num; ++i)
- {
- sm[warp][i * kInterleave + lane / 2] = res[i];
- }
- }
- __syncthreads();
- }
- };
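- // Illustration only (helper names are hypothetical and unused by the kernel): a constexpr mirror of the
- // de-shuffle index arithmetic in the main loop below, specialized for the int4 constants above
- // (kShuffleContinous = 4, kShuffleStrided = 4, kShuffleBasicTile = 2). The value read from register slot
- // src(i, j, t) is written back to logical slot dst(i, j, t), which undoes the permutation row documented above.
- constexpr int int4_deshuffle_src_example(int i, int j, int t)
- {
- return i * 2 + j * 8 + t;
- }
- constexpr int int4_deshuffle_dst_example(int i, int j, int t)
- {
- return i * 8 + j * 2 + t;
- }
- // Register slot 2 holds original element 8 and register slot 8 holds original element 2 (see the "weight"
- // row above), so each must land back at its original logical position.
- static_assert(int4_deshuffle_src_example(1, 0, 0) == 2 && int4_deshuffle_dst_example(1, 0, 0) == 8, "slot 2 -> logical 8");
- static_assert(int4_deshuffle_src_example(0, 1, 0) == 8 && int4_deshuffle_dst_example(0, 1, 0) == 2, "slot 8 -> logical 2");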
- template <typename ActType>
- struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int8b>
- {
- // Every two rows of the original weights are interleaved into a single row with a stride of 64, so if each
- // thread processes 16 elements (for int8, one ldg.128 loads a whole thread's worth of weights), then every
- // group of four adjacent threads alternately processes weights from two different original rows.
- // For example, within every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N of the interleaved
- // layout, the first 64 come from [64*i, 64*(i+1)-1] of row 2N before interleaving and the last 64 come from
- // [64*i, 64*(i+1)-1] of row 2N+1 before interleaving.
- // Hence, if each thread loads 16 int8 elements, the first four and last four threads of each group of 8
- // consecutive threads read from row 2N and row 2N+1, respectively, of the pre-interleave layout.
- static constexpr int kElemBits = 8;
- static constexpr int kInterleave = 2;
- static constexpr int kStride = 64;
- // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm
- // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
- // weight 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
- static constexpr int kShuffleSize = 16;
- static constexpr int kShuffleBasicTile = 2;
- static constexpr int kShuffleContinous = 2;
- static constexpr int kShuffleStrided = 4;
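- // Sanity sketch (illustration only): the shuffle tiling exactly covers one ldg.128 load per thread,
- // i.e. 2 x 4 x 2 = 16 int8 elements = 128 bits.
- static_assert(kShuffleContinous * kShuffleStrided * kShuffleBasicTile == kShuffleSize, "shuffle tiling covers one load");
- static_assert(kShuffleSize * kElemBits == 128, "one ldg.128 worth of int8 weights per thread");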
- // Each warp performs its intra-warp reduction and writes its [Batch * NPerBlock * Interleave] partial results
- // to the corresponding addresses in shared memory
- template <int Num, int WarpSize>
- __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
- {
- #pragma unroll
- for (int i = 0; i < Num; ++i)
- {
- res[i] += __shfl_xor_sync(~0, res[i], 16);
- res[i] += __shfl_xor_sync(~0, res[i], 8);
- res[i] += __shfl_xor_sync(~0, res[i], 2);
- res[i] += __shfl_xor_sync(~0, res[i], 1);
- }
- __syncthreads();
- int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
- if (lane == 0 || lane == 4)
- {
- #pragma unroll
- for (int i = 0; i < Num; ++i)
- {
- sm[warp][i * kInterleave + lane / 4] = res[i];
- }
- }
- __syncthreads();
- }
- };
- template <typename ActType, WeightOnlyQuantType QType>
- struct WeightOnlyKernelDetails
- {
- using Layout = WeightOnlyDetails<ActType, QType>;
- static constexpr int kElemBits = Layout::kElemBits;
- static constexpr int kInterleave = Layout::kInterleave;
- static constexpr int kStride = Layout::kStride;
- static constexpr int kShuffleSize = Layout::kShuffleSize;
- static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile;
- static constexpr int kShuffleContinous = Layout::kShuffleContinous;
- static constexpr int kShuffleStrided = Layout::kShuffleStrided;
- // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace
- // Input int8 data layout
- // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
- //
- // Converted fp16/bf16 data layout
- // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
- // Input int4 data layout
- // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
- //
- // Converted fp16/bf16 data layout
- // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
- static constexpr int kConvertCount = ConverterSelector<ActType, QType>::kConvertCount;
- using Converter = typename ConverterSelector<ActType, QType>::Converter;
- // Use ldg.128 to load data from global memory
- static constexpr int kAccessSize = 128;
- using AccessType = uint4;
- static constexpr int kElemsPerByte = 8 / kElemBits;
- static constexpr int kElemsPerThread = kAccessSize / kElemBits;
- static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte;
- static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread;
- static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave;
- static constexpr int kConvertIters = kElemsPerThread / kConvertCount;
- // Each thread loads 16 (int8) / 32 (int4) quantized weight elements per ldg.128, so loading the matching
- // number of fp16/bf16 activation elements requires several ldg.128 accesses.
- static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8);
- static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess;
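- // Sanity sketch (illustration only): the derived constants all describe the same 128-bit access.
- static_assert(kElemsPerThread * kElemBits == kAccessSize, "each thread dequantizes exactly one ldg.128 of weights");
- static_assert(kBytePerThread == static_cast<int>(sizeof(AccessType)), "quantized bytes per thread match the vectorized load type");
- static_assert(kThreadsNumPerTile * kElemsPerThread == kStride, "the threads of one tile cover the 64-element interleave stride");
- static_assert(static_cast<int>(kActivationElemNumPerAccess * sizeof(ActType) * 8) == kAccessSize, "activation loads are also ldg.128");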
- };
- template <typename WeightOnlyFlag>
- struct WeightOnlyProperties;
- template <>
- struct WeightOnlyProperties<WeightOnlyPerChannel>
- {
- static constexpr bool kIsFineGrained = false;
- static constexpr int kGroupSize = 0;
- };
- template <int GS>
- struct WeightOnlyProperties<WeightOnlyGroupWise<GS>>
- {
- static constexpr bool kIsFineGrained = true;
- static constexpr int kGroupSize = GS;
- };
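- // Usage sketch (illustration only; 64 is just an example group size): per-channel quantization reports no
- // group size, while groupwise quantization exposes its compile-time group size.
- static_assert(!WeightOnlyProperties<WeightOnlyPerChannel>::kIsFineGrained, "one scale per output channel");
- static_assert(WeightOnlyProperties<WeightOnlyGroupWise<64>>::kGroupSize == 64, "scales repeat every kGroupSize k elements");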
- template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
- struct WeightOnlyScaleLoader
- {
- using ElemType = ActType;
- using Details = WeightOnlyKernelDetails<ActType, QType>;
- static constexpr bool kIsFineGrained = WeightOnlyProperties<WeightOnlyFlag>::kIsFineGrained;
- static constexpr int kGroupSize = WeightOnlyProperties<WeightOnlyFlag>::kGroupSize;
- private:
- const ElemType* _scales;
- const ElemType* _zeros;
- int _stride;
- int _offset;
- public:
- __device__ __forceinline__ WeightOnlyScaleLoader(
- const ElemType* scales, const ElemType* zeros, int initial_offset, int stride)
- : _scales(scales)
- , _zeros(zeros)
- , _stride(stride)
- {
- _scales += initial_offset;
- if constexpr (Zero)
- {
- _zeros += initial_offset;
- }
- // Compute the k-dimension index (in the pre-interleave layout) of the first element processed by the
- // current thread; used to load scales and zeros for groupwise weight-only quantization.
- _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride
- + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread;
- }
- __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid)
- {
- int offset = nid * Details::kInterleave;
- if constexpr (kIsFineGrained)
- {
- offset += _offset / kGroupSize * _stride;
- }
- scale = _scales[offset];
- if constexpr (Zero)
- {
- zero = _zeros[offset];
- }
- else
- {
- zero = static_cast<ElemType>(0.f);
- }
- }
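- // Advance the pre-interleave k offset by the number of original k elements consumed by the whole thread
- // block per main-loop iteration (BlockSize threads each covering kElemsPerThread interleaved elements,
- // spread across kInterleave rows).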
- __device__ __forceinline__ void advance()
- {
- _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave;
- }
- __device__ __forceinline__ int offset()
- {
- return _offset;
- }
- };
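- // Illustration only (the helper name is hypothetical and unused by the kernel): a constexpr mirror of the
- // _offset formula above, evaluated with the int4 constants (kThreadsNumPerTile = 2,
- // kThreadsNumPerInterleave = 8, kStride = 64, kElemsPerThread = 32).
- constexpr int scale_loader_k_offset_int4_example(int tid)
- {
- return tid / 8 * 64 + (tid % 2) * 32;
- }
- static_assert(scale_loader_k_offset_int4_example(0) == 0, "thread 0 starts at k = 0");
- static_assert(scale_loader_k_offset_int4_example(1) == 32, "thread 1 covers the second half of the same 64-wide tile");
- static_assert(scale_loader_k_offset_int4_example(2) == 0, "threads 2..7 revisit the same k range for the other interleaved rows");
- static_assert(scale_loader_k_offset_int4_example(8) == 64, "threads 8..15 move on to the next 64-wide tile");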
- template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
- bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
- __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
- const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
- {
- static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
- using ActType2 = typename ActTypeDetails<ActType>::Vec2;
- using Details = WeightOnlyKernelDetails<ActType, QType>;
- using Converter = typename Details::Converter;
- using AccType = typename Details::AccessType;
- using CvtSrcType = typename Converter::source_type;
- using CvtResType = typename Converter::result_type;
- using ScaleLoader = WeightOnlyScaleLoader<ActType, QType, WeightOnlyFlag, Zero, BlockSize>;
- extern __shared__ uint8_t shmem[];
- constexpr int Interleave = Details::kInterleave;
- constexpr int WarpSize = 32;
- constexpr int Num = Batch * NPerBlock;
- const int tid = threadIdx.x;
- const int bid = blockIdx.x;
- const int n_start_id = bid * NPerBlock * Interleave;
- // Compute the n-dimension index, within the interleave tile, of the data processed by the current thread
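- // For example, with int4 weights (kThreadsNumPerTile = 2, Interleave = 4): threads 0-1 map to interleaved
- // row 0, threads 2-3 to row 1, threads 4-5 to row 2, threads 6-7 to row 3, and the pattern then repeats.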
- const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave;
- qweight += n_start_id * k / Details::kElemsPerByte;
- ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n);
- float(*sm)[Num * Interleave] = reinterpret_cast<float(*)[Num * Interleave]>(shmem);
- // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation
- // between threads.
- ActType accumulator[Num];
- for (int i = 0; i < Num; ++i)
- {
- accumulator[i] = static_cast<ActType>(0.f);
- }
- // Iterate over the k dimension
- for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave;
- local_k += BlockSize * Details::kElemsPerThread)
- {
- ActType weights_f16[Details::kElemsPerThread * NPerBlock];
- ActType scale[NPerBlock], zero[NPerBlock];
- #pragma unroll
- for (int idx = 0; idx < NPerBlock; ++idx)
- {
- // Load quantized weight and scales/zeros
- uint8_t weights_quantized[Details::kBytePerThread];
- load<AccType>(weights_quantized,
- qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte);
- scale_loader.load(scale[idx], zero[idx], idx);
- ActType weights_vec[Details::kElemsPerThread];
- #pragma unroll
- for (int i = 0; i < Details::kConvertIters; ++i)
- {
- // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F type conversion
- assign<CvtResType>(weights_vec + i * Details::kConvertCount,
- Converter::convert(*reinterpret_cast<CvtSrcType*>(
- weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte)));
- }
- #pragma unroll
- for (int i = 0; i < Details::kShuffleContinous; ++i)
- {
- #pragma unroll
- for (int j = 0; j < Details::kShuffleStrided; ++j)
- {
- // Dequantize the weights and arrange the shuffled elements back to the correct order in the
- // register array
- ActType2 v = *reinterpret_cast<ActType2*>(weights_vec + i * Details::kShuffleBasicTile
- + j * Details::kShuffleContinous * Details::kShuffleBasicTile);
- v = __hfma2(
- v, ActTypeDetails<ActType>::to_vec2(scale[idx]), ActTypeDetails<ActType>::to_vec2(zero[idx]));
- weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
- + j * Details::kShuffleBasicTile + 0)
- * NPerBlock
- + idx]
- = v.x;
- weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
- + j * Details::kShuffleBasicTile + 1)
- * NPerBlock
- + idx]
- = v.y;
- }
- }
- }
- ActType act_scale_v[Details::kElemsPerThread];
- if constexpr (ActScale)
- {
- #pragma unroll
- for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
- {
- load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
- act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
- }
- }
- #pragma unroll
- for (int b = 0; b < Batch; ++b)
- {
- ActType in_v[Details::kElemsPerThread];
- #pragma unroll
- for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
- {
- // load activation elements
- load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
- in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
- if constexpr (ActScale)
- {
- #pragma unroll
- for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
- {
- *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
- *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
- *reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
- }
- }
- }
- // Perform vector inner product and accumulate
- if constexpr (NPerBlock == 1)
- {
- ActType2 v = ActTypeDetails<ActType>::to_vec2(static_cast<ActType>(0.f));
- #pragma unroll
- for (int y = 0; y < Details::kElemsPerThread; y += 2)
- {
- v = __hfma2(
- *reinterpret_cast<ActType2*>(weights_f16 + y), *reinterpret_cast<ActType2*>(in_v + y), v);
- }
- accumulator[b] += __hadd(v.x, v.y);
- }
- else
- {
- #pragma unroll
- for (int x = 0; x < NPerBlock / 2; ++x)
- {
- #pragma unroll
- for (int y = 0; y < Details::kElemsPerThread; ++y)
- {
- *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2)
- = __hfma2(*reinterpret_cast<ActType2*>(weights_f16 + y * NPerBlock + x * 2),
- ActTypeDetails<ActType>::to_vec2(in_v[y]),
- *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2));
- }
- }
- }
- }
- scale_loader.advance();
- }
- float reses[Num];
- #pragma unroll
- for (int i = 0; i < Num; ++i)
- {
- reses[i] = static_cast<float>(accumulator[i]);
- }
- // Each warp performs its intra-warp reduction and writes its [Batch * NPerBlock * Interleave] partial results
- // to the corresponding addresses in shared memory
- Details::Layout::sync<Num, WarpSize>(reses, sm);
- // Each thread accumulates the per-warp partial sums for one output element and stores the result to global memory
- for (int i = tid; i < Num * Interleave; i += BlockSize)
- {
- int nid = i % (NPerBlock * Interleave);
- float v = 0.f;
- for (int j = 0; j < BlockSize / WarpSize; ++j)
- {
- v += sm[j][i];
- }
- float bias_v = 0.f;
- if constexpr (Bias)
- {
- bias_v = static_cast<float>(bias[n_start_id + nid]);
- }
- int b = i / NPerBlock / Interleave;
- out[b * n + n_start_id + nid] = static_cast<ActType>(ActOp<float>::apply(v + bias_v));
- }
- }
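- // Shared-memory sizing sketch (illustration only; the numbers are one example configuration): with
- // BlockSize = 256, Batch = 1, NPerBlock = 2 and int4 weights (Interleave = 4), each of the 256 / 32 = 8 warps
- // writes 1 * 2 * 4 floats, i.e. 8 * 8 * 4 bytes = 256 bytes of dynamic shared memory, matching the `size`
- // computed by the launcher below.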
- template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
- bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
- __global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
- const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
- {
- if constexpr (std::is_same_v<ActType, half>)
- {
- weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
- BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
- }
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
- else if (std::is_same_v<ActType, nv_bfloat16>)
- {
- weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
- BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
- }
- #endif
- }
- template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
- int NPerBlock, int Batch, int BlockSize>
- struct WeightOnlyBatchedGemvKernelLauncher
- {
- static void run(const WeightOnlyParams& params, cudaStream_t stream)
- {
- if (params.act_type == WeightOnlyActivationType::FP16)
- {
- constexpr int kInterleave = WeightOnlyDetails<half, QType>::kInterleave;
- dim3 grid(params.n / NPerBlock / kInterleave);
- dim3 block(BlockSize);
- int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
- if (params.act_scale != nullptr)
- {
- weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock, Batch,
- BlockSize><<<grid, block, size, stream>>>(params.qweight,
- reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
- reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
- reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
- params.k);
- }
- else
- {
- weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
- Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
- reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
- reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
- reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
- params.k);
- }
- }
- #if defined(ENABLE_BF16)
- else if (params.act_type == WeightOnlyActivationType::BF16)
- {
- constexpr int kInterleave = WeightOnlyDetails<nv_bfloat16, QType>::kInterleave;
- dim3 grid(params.n / NPerBlock / kInterleave);
- dim3 block(BlockSize);
- int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
- if (params.act_scale != nullptr)
- {
- weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
- NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
- reinterpret_cast<const __nv_bfloat16*>(params.scales),
- reinterpret_cast<const __nv_bfloat16*>(params.zeros),
- reinterpret_cast<const __nv_bfloat16*>(params.in),
- reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
- reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
- params.n, params.k);
- }
- else
- {
- weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
- NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
- reinterpret_cast<const __nv_bfloat16*>(params.scales),
- reinterpret_cast<const __nv_bfloat16*>(params.zeros),
- reinterpret_cast<const __nv_bfloat16*>(params.in),
- reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
- reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
- params.n, params.k);
- }
- }
- #endif
- }
- };
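- // Usage sketch (illustration only; the launcher is normally instantiated elsewhere in the project, and the
- // ActOp functor name below is a placeholder):
- //   WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
- //       IdentityActivation, /*Zero=*/false, /*Bias=*/true, /*NPerBlock=*/2, /*Batch=*/1,
- //       /*BlockSize=*/256>::run(params, stream);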
- } // namespace kernels
- } // namespace tensorrt_llm