/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common.h"
#include "utility.h"

#include <cassert>
#include <stdexcept>

namespace tensorrt_llm
{
namespace kernels
{
// Templated kernel launcher. run() launches the batched GEMV kernel built for the given
// compile-time configuration; the definitions are provided by explicit specializations elsewhere.
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
    int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
struct WeightOnlyBatchedGemvKernelLauncher
{
    static void run(const WeightOnlyParams& params, cudaStream_t stream);
};

// Dispatch on whether zero points and/or a bias pointer are present.
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, int N_PER_BLOCK,
    int BATCH, int BLOCK_SIZE>
void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream)
{
    if (params.zeros && params.bias)
    {
        WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, true, true, N_PER_BLOCK, BATCH,
            BLOCK_SIZE>::run(params, stream);
    }
    else if (params.zeros && !params.bias)
    {
        WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, true, false, N_PER_BLOCK, BATCH,
            BLOCK_SIZE>::run(params, stream);
    }
    else if (!params.zeros && params.bias)
    {
        WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, false, true, N_PER_BLOCK, BATCH,
            BLOCK_SIZE>::run(params, stream);
    }
    else
    {
        WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, false, false, N_PER_BLOCK, BATCH,
            BLOCK_SIZE>::run(params, stream);
    }
}

// Dispatch on the activation function fused into the kernel epilogue.
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
void select_activation(const WeightOnlyParams& params, cudaStream_t stream)
{
    switch (params.act_func_type)
    {
    // Currently, the activation function is not called in the plugin
#if 0
    case WeightOnlyActivationFunctionType::Gelu:
    {
        select_zero_bias<QType, WeightOnlyFlag, GeluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
        break;
    }
    case WeightOnlyActivationFunctionType::Relu:
    {
        select_zero_bias<QType, WeightOnlyFlag, ReluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
        break;
    }
#endif
    case WeightOnlyActivationFunctionType::Identity:
    {
        select_zero_bias<QType, WeightOnlyFlag, IdentityActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
        break;
    }
    default:
    {
        throw std::runtime_error("Use unsupported activation");
        break;
    }
    }
}

// Dispatch on the weight quantization type (int4 or int8).
template <typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream)
{
    if (params.quant_type == WeightOnlyQuantType::Int4b)
    {
        select_activation<WeightOnlyQuantType::Int4b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
    }
    else if (params.quant_type == WeightOnlyQuantType::Int8b)
    {
        select_activation<WeightOnlyQuantType::Int8b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
    }
    else
    {
        throw std::runtime_error("Unknown QuantType");
    }
}

// Dispatch on the quantization group size; only group sizes 64 and 128 are supported.
template <int N_PER_BLOCK, int BATCH, int BLOCK_SIZE>
void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream)
{
    if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64)
    {
        select_quant_type<WeightOnlyGroupWise<64>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
    }
    else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128)
    {
        select_quant_type<WeightOnlyGroupWise<128>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream);
    }
    else
    {
        throw std::runtime_error("Only support groupwise weight only for gs=64/128");
    }
}

// Entry point: choose the kernel configuration from the runtime batch size (params.m) and the
// quantization scheme (per-channel or groupwise, int4 or int8).
void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream)
{
    assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity);
    assert(params.weight_only_type == WeightOnlyType::GroupWise
        || (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr
            && params.zeros == nullptr));
    if (params.weight_only_type == WeightOnlyType::PerChannel)
    {
        // Per-channel kernels never use zero points or bias; the tile configuration
        // (N_PER_BLOCK, BATCH, BLOCK_SIZE) mirrors the groupwise dispatch below.
        if (params.quant_type == WeightOnlyQuantType::Int4b)
        {
            switch (params.m)
            {
            case 1:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 1, 256>::run(params, stream);
                break;
            }
            case 2:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 2, 256>::run(params, stream);
                break;
            }
            case 3:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 3, 128>::run(params, stream);
                break;
            }
            case 4:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 4, 128>::run(params, stream);
                break;
            }
            default:
            {
                throw std::runtime_error("Weight only cuda kernel only supported bs <= 4");
                break;
            }
            }
        }
        else if (params.quant_type == WeightOnlyQuantType::Int8b)
        {
            switch (params.m)
            {
            case 1:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 1, 256>::run(params, stream);
                break;
            }
            case 2:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 2, 256>::run(params, stream);
                break;
            }
            case 3:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 3, 128>::run(params, stream);
                break;
            }
            case 4:
            {
                WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel,
                    IdentityActivation, false, false, 2, 4, 128>::run(params, stream);
                break;
            }
            default:
            {
                throw std::runtime_error("Weight only cuda kernel only supported bs <= 4");
                break;
            }
            }
        }
    }
    else if (params.weight_only_type == WeightOnlyType::GroupWise)
    {
        switch (params.m)
        {
        case 1:
        {
            select_groupwise_weight_only<2, 1, 256>(params, stream);
            break;
        }
        case 2:
        {
            select_groupwise_weight_only<2, 2, 256>(params, stream);
            break;
        }
        case 3:
        {
            select_groupwise_weight_only<2, 3, 128>(params, stream);
            break;
        }
        case 4:
        {
            select_groupwise_weight_only<2, 4, 128>(params, stream);
            break;
        }
        default:
        {
            throw std::runtime_error("Weight only cuda kernel only supported bs <= 4");
            break;
        }
        }
    }
}
} // namespace kernels
} // namespace tensorrt_llm