generate_kernels.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
  2. # This file is run to generate the kernel instantiations for the flash_attn kernels
  3. # They are written to several files in order to speed up compilation
import argparse
import itertools
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, List, Optional
  10. KERNEL_BATCH = namedtuple("Kernel", ["template", "filename"])
  11. DTYPE_MAP = {
  12. "fp16": "cutlass::half_t",
  13. "bf16": "cutlass::bfloat16_t",
  14. "e4m3": "cutlass::float_e4m3_t",
  15. }
  16. DTYPE_MAP_FWD_SM80 = {
  17. "fp16": "cutlass::half_t",
  18. "bf16": "cutlass::bfloat16_t",
  19. }
  20. DTYPE_MAP_BWD = {
  21. "fp16": "cutlass::half_t",
  22. "bf16": "cutlass::bfloat16_t",
  23. }
  24. SM = [80, 90] # Sm kernels support up to
  25. HEAD_DIMENSIONS = [64, 96, 128, 192, 256]
  26. PAGEDKV = [False, True]
  27. SPLIT = [False, True]
  28. SOFTCAP = [False, True]
  29. PACKGQA = [False, True]
  30. KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
  31. #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
  32. template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);
  33. #endif
  34. """
  35. KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
  36. #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
  37. template<>
  38. void run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
  39. run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}>(params, stream);
  40. }}
  41. #endif
  42. """
  43. @dataclass
  44. class Kernel:
  45. sm: int
  46. dtype: str
  47. head_dim: int
  48. split: bool
  49. paged_kv: bool
  50. softcap: bool
  51. packgqa: bool
  52. direction: str
  53. @property
  54. def template(self) -> str:
  55. if self.direction == "fwd":
  56. return KERNEL_IMPL_TEMPLATE_FWD.format(
  57. ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
  58. SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
  59. SOFTCAP=str(self.softcap).lower(), PACKGQA=str(self.packgqa).lower()
  60. )
  61. elif self.direction == "bwd":
  62. return KERNEL_IMPL_TEMPLATE_BWD.format(
  63. ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
  64. )
  65. @property
  66. def filename(self) -> str:
  67. return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu"
  68. def get_all_kernels() -> List[Kernel]:
  69. for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
  70. if sm >= 90 or dtype in DTYPE_MAP_FWD_SM80:
  71. yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
  72. for dtype, head_dim, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SM):
  73. yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=False, packgqa=False, direction="bwd")
  74. def batch_hdim(kernels_all) -> List[KERNEL_BATCH]:
  75. for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
  76. kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm]
  77. if len(kernels) > 0:
  78. filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu"
  79. template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
  80. yield KERNEL_BATCH(template, filename)
  81. def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
  82. prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  83. // Splitting the different template instantiations to different files to speed up compilation.
  84. // This file is auto-generated. See "generate_kernels.py"\n
  85. """
  86. (autogen_dir / kernel.filename).write_text(prelude + kernel.template)
  87. def main(output_dir: Optional[str]) -> None:
  88. output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent
  89. output_dir.mkdir(parents=True, exist_ok=True)
  90. kernels_all = list(get_all_kernels())
  91. for kernel in kernels_all:
  92. write_kernel(kernel, output_dir)
  93. for kernel in batch_hdim(kernels_all):
  94. write_kernel(kernel, output_dir)
  95. if __name__ == "__main__":
  96. parser = argparse.ArgumentParser(
  97. prog="generate_kernels",
  98. description="Generate the flash_attention kernels template instantiations",
  99. )
  100. # Set an optional output directory
  101. parser.add_argument(
  102. "-o",
  103. "--output_dir",
  104. default="instantiations",
  105. required=False,
  106. help="Where to generate the kernels "
  107. " will default to the current directory ",
  108. )
  109. args = parser.parse_args()
  110. main(args.output_dir)