configs.h 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. // Copyright 2024 FP6-LLM authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // This file is copied from
  16. // https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h
  17. #ifndef CONFIGS_H
  18. #define CONFIGS_H
  // Optional debug switch; left disabled here.
  // #define DEBUG_MODE

  /************************ Pipeline Parameters ************************/
  // Global-memory pipeline depth. TilingConfig::SMEM_SIZE_B_TILE reserves one
  // B tile per stage ("doubleBuffer=2").
  #define PIPELINE_LEVEL_GMEM 2
  // Shared-memory pipeline depth. Only a value of 2 is supported; enforce that
  // at preprocessing time instead of relying on a comment.
  #define PIPELINE_LEVEL_SMEM 2
  #if PIPELINE_LEVEL_SMEM != 2
  #error "PIPELINE_LEVEL_SMEM: only a value of 2 is supported"
  #endif

  /************************ Hardware Parameters ************************/
  #define WARP_SIZE 32
  #define REG_BIT_WIDTH 32
  // Tensor-core MMA shape: M=16 K=16 N=8.
  #define MMA_8 8
  #define MMA_16 16
  // Per-thread memory access widths.
  #define THREAD_OPT_ACCESS_BIT_WIDTH_128 128  // LDS.128, cp_async.128, ...
  #define BIT_WIDTH_PER_HALF 16                // Half precision: FP16

  /******************** Register Allocation For GEMM ********************/
  // A 16x16 FP32 accumulator tile spread over one warp:
  // (16 * 16) / WARP_SIZE = 8 registers per thread.
  #define REG_PER_THREAD_C_TENSOR_16_16 8  // 8 for FP32 Accumulation

  /********************** Memory Padding Parameters **********************/
  // Padding used to eliminate shared-memory bank conflicts.
  // NOTE(review): despite its name, TilingConfig adds PADDING_BYTES_16 to
  // element counts (TILE_K halves / TILE_M floats) when sizing shared memory,
  // so there it acts as 16 elements of padding, not 16 bytes. Confirm intent.
  #define PADDING_BYTES_16 16  // Padding 16 bytes each column
  // Padding of 8 halves each column, during CopyFromGlobalToShared() for B.
  #define PADDING_SHARED_MEM_FOR_B_8 8
  // Padding of 4 floats each column, during StoreToSharedMemoryFromRegister()
  // for C.
  #define PADDING_SHARED_MEM_FOR_C_4 4

  /************************* WARP Tiling part-1 *************************/
  // Per-warp tile extents along M and K are fixed here; the N extent varies
  // per configuration (TilingConfig::WARP_N).
  #define WARP_ROW_MMA_TENSORS 4
  #define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16)  // 64
  #define WARP_K_MMA_TENSORS 4
  #define WARP_K (WARP_K_MMA_TENSORS * MMA_16)  // 64
  46. template <int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
  47. struct TilingConfig {
  48. // Depending on "n" dimension of the GEMM
  49. static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
  50. static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
  51. static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
  52. /************************* WARP Tiling part-2 *************************/
  53. static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
  54. /*************************Thread Block Tiling *************************/
  55. static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
  56. static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
  57. static constexpr int TILE_K = WARP_K;
  58. /********************** #Thread per Thread Block **********************/
  59. static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
  60. static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
  61. /******************************* Others *******************************/
  62. static constexpr int SMEM_SIZE_B_TILE =
  63. TILE_N * (TILE_K + PADDING_BYTES_16) * 2 *
  64. PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2
  65. static constexpr int SMEM_SIZE_C_TILE =
  66. TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4
  67. };
  68. #endif // CONFIGS_H