generate_kernels.py

# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
# This file is run to generate the kernel instantiations for the flash_attn kernels
# They are written to several files in order to speed up compilation
import argparse
import itertools
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

KERNEL_BATCH = namedtuple("Kernel", ["template", "filename"])

DTYPE_MAP = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
    "e4m3": "cutlass::float_e4m3_t",
}

DTYPE_MAP_FWD_SM8x = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

DTYPE_MAP_BWD = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

SM = [80, 90]  # Sm kernels support up to
HEAD_DIMENSIONS = [64, 96, 128, 192, 256]
PAGEDKV = [False, True]
SPLIT = [False, True]
SOFTCAP = [False, True]
PACKGQA = [False, True]

KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);
#endif
"""

KERNEL_IMPL_TEMPLATE_FWD_SM8x = """#include "flash_fwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif
"""

KERNEL_IMPL_TEMPLATE_BWD_SM90 = """#include "flash_bwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template<>
void run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}, {SOFTCAP}>(params, stream);
}}
#endif
"""

KERNEL_IMPL_TEMPLATE_BWD_SM8x = """#include "flash_bwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template<>
void run_mha_bwd_<80, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<80, {DTYPE}, {SOFTCAP}>(params, stream);
}}
template<>
void run_mha_bwd_<86, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<86, {DTYPE}, {SOFTCAP}>(params, stream);
}}
#endif
#endif
"""


@dataclass
class Kernel:
    sm: int
    dtype: str
    head_dim: int
    split: bool
    paged_kv: bool
    softcap: bool
    packgqa: bool
    direction: str

    @property
    def template(self) -> str:
        if self.direction == "fwd":
            if self.sm == 90:
                # Always enable PackGQA for PagedKV or Split to reduce compilation
                packgqa = self.packgqa or self.paged_kv or self.split
                return KERNEL_IMPL_TEMPLATE_FWD_SM90.format(
                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
                    SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower()
                )
            else:
                # Always enable PackGQA for Sm8x to reduce compilation
                return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format(
                    DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
                    SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower()
                )
        elif self.direction == "bwd":
            if self.sm == 90:
                return KERNEL_IMPL_TEMPLATE_BWD_SM90.format(
                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SOFTCAP=str(self.softcap).lower()
                )
            else:
                return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format(
                    DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SOFTCAP=str(self.softcap).lower()
                )

    @property
    def filename(self) -> str:
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu"


def get_all_kernels() -> List[Kernel]:
    for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
        # We always enable PackGQA for Sm8x or PagedKV or Split
        # so we should just pass in packgqa=False to avoid the `_packgqa` in the filename.
        if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))):
            continue
        if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
    for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd")
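
# Illustrative check (derived from the loops above): the first kernel yielded is the
# plain fp16, head_dim 64 forward kernel for SM80, i.e.
#   >>> next(get_all_kernels()).filename
#   'flash_fwd_hdim64_fp16_sm80.cu'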


def batch_hdim(kernels_all) -> List[KERNEL_BATCH]:
    # Batch all head-dim variants of an SM90 forward configuration into a single
    # file that simply #includes the per-head-dim .cu files.
    for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
        if sm < 90:
            continue
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)
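
# Illustrative example: for dtype "bf16" on SM90 with no flags set, this yields
# flash_fwd_hdimall_bf16_sm90.cu containing one #include per head dimension, e.g.
#   #include "flash_fwd_hdim64_bf16_sm90.cu"
#   ...
#   #include "flash_fwd_hdim256_bf16_sm90.cu"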


def batch_softcap(kernels_all) -> List[KERNEL_BATCH]:
    # Batch the softcap on/off variants of a configuration into a single file:
    # first the SM8x forward kernels, then the SM90 backward kernels.
    for dtype, head_dim, split, paged_kv, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, PACKGQA, SM):
        if sm >= 90:
            continue
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.head_dim == head_dim and k.split == split and k.paged_kv == paged_kv and k.packgqa == packgqa and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdim{head_dim}_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}_softcapall{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)
    # Bwd
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        if sm < 90:
            continue
        kernels = [k for k in kernels_all if k.direction == "bwd" and k.dtype == dtype and k.head_dim == head_dim and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_bwd_hdim{head_dim}_{dtype}_softcapall_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)
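
# Illustrative example: the backward pass for head_dim=128, dtype "fp16" on SM90 is
# batched into flash_bwd_hdim128_fp16_softcapall_sm90.cu, which #includes both
# flash_bwd_hdim128_fp16_sm90.cu and flash_bwd_hdim128_fp16_softcap_sm90.cu.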


def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    # Also called with KERNEL_BATCH entries, which expose the same
    # `template` and `filename` fields.
    prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
    output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    kernels_all = list(get_all_kernels())
    for kernel in kernels_all:
        write_kernel(kernel, output_dir)
    for kernel in batch_hdim(kernels_all):
        write_kernel(kernel, output_dir)
    for kernel in batch_softcap(kernels_all):
        write_kernel(kernel, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernel template instantiations",
    )
    # Set an optional output directory
    parser.add_argument(
        "-o",
        "--output_dir",
        default="instantiations",
        required=False,
        help="Where to write the generated kernel files; defaults to the 'instantiations' directory",
    )
    args = parser.parse_args()
    main(args.output_dir)
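
# Example invocation (illustrative; the output path below is arbitrary):
#   python generate_kernels.py                          # writes the .cu files into ./instantiations
#   python generate_kernels.py --output_dir path/to/instantiations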