// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h
#ifndef CONFIGS_H
#define CONFIGS_H

// Compile-time configuration constants for the GEMM kernels, plus the
// TilingConfig template that derives per-thread-block tile sizes from them.
//
// The constants are kept as preprocessor macros (not constexpr variables) so
// that any code testing them with #if / #ifdef keeps working.

// Debug switch; left disabled. Uncomment to define DEBUG_MODE.
// #define DEBUG_MODE

#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2  // only support 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE 32
#define REG_BIT_WIDTH 32
// mma: M=16 K=16 N=8
#define MMA_8 8
#define MMA_16 16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128  // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF 16                // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16 8  // 8 for FP32 Accumulation

/********************** Memory Padding Parameters **********************/
// Eliminating bank-conflict
#define PADDING_BYTES_16 16  // Padding 16 bytes each column
// Padding 8 half each column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_B_8 8
// Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C
#define PADDING_SHARED_MEM_FOR_C_4 4

/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS 4
#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16)  // 64
#define WARP_K_MMA_TENSORS 4
#define WARP_K (WARP_K_MMA_TENSORS * MMA_16)  // 64

// Compile-time tiling description for one thread block.
//
// Template parameters (each is re-exported as a static member without the
// trailing underscore):
//   BLOCK_ROW_WARPS_      - multiplies WARP_M to give TILE_M.
//   BLOCK_COL_WARPS_      - multiplies WARP_N to give TILE_N.
//   WARP_COL_MMA_TENSORS_ - multiplies MMA_8 to give WARP_N.
//
// All members are static constexpr int, so the struct carries no state and is
// only ever used as a type, e.g. TilingConfig<4, 1, 1>::TILE_M.
template <int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
  // Depending on "n" dimension of the GEMM
  static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
  static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
  static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;

  /************************* WARP Tiling part-2 *************************/
  static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;

  /*************************Thread Block Tiling *************************/
  static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
  static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
  static constexpr int TILE_K = WARP_K;

  /********************** #Thread per Thread Block **********************/
  static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
  static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;

  /******************************* Others *******************************/
  // NOTE(review): both sizes below add PADDING_BYTES_16 (16) to an element
  // count, while the element-count paddings defined above
  // (PADDING_SHARED_MEM_FOR_B_8 / PADDING_SHARED_MEM_FOR_C_4) are not used
  // here. Values are kept exactly as upstream; confirm this is intended.

  // B tile: TILE_N * (TILE_K + 16) elements * sizeof(half)=2 * doubleBuffer=2.
  static constexpr int SMEM_SIZE_B_TILE =
      TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM;
  // C tile: TILE_N * (TILE_M + 16) elements * sizeof(float)=4.
  static constexpr int SMEM_SIZE_C_TILE =
      TILE_N * (TILE_M + PADDING_BYTES_16) * 4;
};

#endif  // CONFIGS_H