/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Declarations for the autoquant 4-bit-weight GEMM front end: an Impl
// template that holds candidate kernels and a GemmS4F16 facade exposing
// Measure() and Run().
//
// NOTE(review): in this copy every piece of text that sat between angle
// brackets appears to have been stripped. That affects the #include targets,
// the template parameter lists, and the std::vector / std::unique_ptr
// arguments. As written the header cannot compile. Each affected spot is
// marked below. TODO restore from upstream rather than guessing.

#pragma once

// NOTE(review): header names missing. The declarations below use std::vector,
// std::unique_ptr, cudaStream_t and cudaEvent_t, so these presumably included
// at least <vector>, <memory> and <cuda_runtime.h>. TODO confirm the exact
// list.
#include
#include
#include
#include
#include

#include "gemm_s4_f16_kernel.h"
#include "metric.h"

namespace aphrodite {
namespace autoquant {

// Declared extern here, so it is defined in another translation unit. Nothing
// in this header reads it. The name suggests a one-shot "dump kernel info"
// debug switch. TODO confirm.
extern bool g_dump_kernel_info_once;

// Value passed as the `type` argument of Measure() and Run().
// The enumerator names suggest a plain GEMM (kGemm) versus a GEMM fused with a
// SiLU-gated FFN step (kFusedSiluFfn). The exact semantics live in the
// implementation, which is not shown. TODO confirm.
enum Type { kGemm, kFusedSiluFfn };

// Implementation object for GemmS4F16. It holds lists of candidate kernels
// (kernels_), a Measure() entry point that fills a vector of Metric, and a
// Run() entry point that takes an algorithm id.
// NOTE(review): template parameter list missing. The members use T_BC and
// T_Q, so this was presumably `template <typename T_BC, typename T_Q>`.
// TODO restore.
template struct Impl{
    // Alias for one list of kernels. NOTE(review): template arguments missing.
    // The trailing ">>" suggests a vector of std::unique_ptr to a kernel type.
    // TODO restore.
    using Kernels = std::vector>>;

    // Presumably populates `kernels` with the candidate kernel lists.
    // The implementation is not shown. (Element type missing in this copy.)
    void Generate(std::vector& kernels);

    // Presumably runs the candidate kernels in `_kernels` on stream `st` and
    // appends one Metric per candidate to `metrics`. The implementation is not
    // shown.
    //   C           - output matrix (T_BC elements).
    //   A           - uint-packed operand. Given the file name, presumably
    //                 4-bit quantized weights. TODO confirm layout.
    //   B           - T_BC operand.
    //   Q           - T_Q quantization parameters. Presumably per-group
    //                 scales/zeros. TODO confirm.
    //   m, n, k     - problem dimensions.
    //   group_size  - quantization group size.
    //   type        - which Type variant to run.
    //   metrics     - output vector of Metric (element type missing in this
    //                 copy).
    //   st          - CUDA stream to launch on.
    //   _kernels    - candidate kernel lists (element type missing in this
    //                 copy).
    void Measure(T_BC* C, const uint* A, const T_BC* B, const T_Q* Q,
                 int m, int n, int k, int group_size, Type type,
                 std::vector& metrics, cudaStream_t st,
                 std::vector& _kernels);

    // Sort comparator over Metric entries. It orders by, in priority:
    //   1. `feasible`  - greater value first (true before false if it is a
    //                    bool),
    //   2. `prefer`    - greater value first,
    //   3. `grid_norm` - smaller value first.
    // Returns true when `a` should be placed before `b`.
    static bool Compare(const Metric& a, const Metric& b)
    {
        if (a.feasible != b.feasible) {
            return a.feasible > b.feasible;
        }
        if (a.prefer != b.prefer) {
            return a.prefer > b.prefer;
        }
        return a.grid_norm < b.grid_norm;
    }

    // Returns an int for an m x n x k problem given `kernels`. Presumably the
    // index of the estimated best algorithm. TODO confirm.
    int Estimate(int m, int n, int k, Kernels& kernels);

    // Presumably launches the kernel selected by `algo_id` on stream `st`,
    // writing C. The other arguments have the same meaning as in Measure().
    // The implementation is not shown.
    void Run(T_BC* C, const uint* A, const T_BC* B, const T_Q* Q,
             int m, int n, int k, int group_size, Type type, int algo_id,
             cudaStream_t st, std::vector& kernels);

    // Constructor and destructor are defined elsewhere. Presumably they
    // create and destroy ev_start_ / ev_end_. TODO confirm.
    Impl();
    ~Impl();

    // Candidate kernel lists and group sizes. Presumably one entry per
    // supported group size, in parallel. (Element types missing in this
    // copy.) TODO confirm.
    std::vector kernels_;
    std::vector group_sizes_;

    // Iteration counts. Presumably the warm-up and timed launches per
    // candidate in Measure(). TODO confirm.
    static constexpr int kWarmup = 10;
    static constexpr int kMeasure = 100;

    // CUDA events, value-initialized (null) here. Presumably bracket the timed
    // launches in Measure(). TODO confirm.
    cudaEvent_t ev_start_{};
    cudaEvent_t ev_end_{};
};

// Public front end. It owns its implementation through `impl_`, and
// Measure()/Run() have the same parameters as the Impl members of the same
// name, minus the kernel list. Presumably they forward to impl_. The
// definitions are not shown.
// NOTE(review): template parameter list missing. Presumably
// `template <typename T_BC, typename T_Q>`. TODO restore.
template class GemmS4F16 {
public:
    GemmS4F16();

    ~GemmS4F16();

    // See Impl::Measure for the parameter meanings.
    void Measure(T_BC* C, const uint* A, const T_BC* B, const T_Q* Q,
                 int m, int n, int k, int group_size, Type type,
                 std::vector& metrics, cudaStream_t st);

    // See Impl::Run for the parameter meanings.
    void Run(T_BC* C, const uint* A, const T_BC* B, const T_Q* Q,
             int m, int n, int k, int group_size, Type type, int algo_id,
             cudaStream_t st);

private:
    //struct Impl;
    // (The nested declaration above is commented out; Impl is declared as a
    // namespace-scope template earlier in this file.)

    // NOTE(review): template arguments missing. Presumably a std::unique_ptr
    // to the matching Impl instantiation. TODO restore.
    std::unique_ptr> impl_;
};

}  // namespace autoquant
}  // namespace aphrodite