generate_kernels.py

# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
# This file is run to generate the kernel instantiations for the flash_attn kernels
# They are written to several files in order to speed up compilation
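# For example (illustrative), the fwd/bf16/hdim128 instantiation with split and
# paged-KV disabled ends up in a file named "flash_fwd_hdim128_bf16_sm90.cu".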
import argparse
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, Optional

DTYPE_MAP = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
    "e4m3": "cutlass::float_e4m3_t",
}

DTYPE_MAP_BWD = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

SM = [90]  # Only Sm90 kernels are generated here
HEAD_DIMENSIONS = [64, 96, 128, 192, 256]
PAGEDKV = ["false", "true"]
SPLIT = ["false", "true"]

KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}>(Flash_fwd_params &params, cudaStream_t stream) {{
    run_mha_fwd_16b<{DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_FP8 = """#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}>(Flash_fwd_params &params, cudaStream_t stream) {{
    run_mha_fwd_8b<{DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""


@dataclass
class Kernel:
    sm: int
    dtype: str
    head_dim: int
    split: str
    paged_kv: str
    direction: str

    @property
    def template(self) -> str:
        if self.direction == "fwd" and self.dtype != "e4m3":
            return KERNEL_IMPL_TEMPLATE_FWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SPLIT=self.split, PAGEDKV=self.paged_kv
            )
        if self.direction == "fwd":
            return KERNEL_IMPL_TEMPLATE_FWD_FP8.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SPLIT=self.split, PAGEDKV=self.paged_kv
            )
        elif self.direction == "bwd":
            return KERNEL_IMPL_TEMPLATE_BWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )

    @property
    def filename(self) -> str:
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'paged_' if self.paged_kv == 'true' else ''}{'split_' if self.split == 'true' else ''}sm{self.sm}.cu"
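
    # e.g. "flash_fwd_hdim64_fp16_paged_split_sm90.cu" when paged_kv and split
    # are both "true", or "flash_bwd_hdim256_bf16_sm90.cu" for a backward kernel.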


def get_all_kernels() -> Iterator[Kernel]:
    for dtype, head_dim, split, paged_kv, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, direction="fwd")
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split='false', paged_kv='false', direction="bwd")
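
# With the grids above: 3 fwd dtypes x 5 head dims x 2 SPLIT x 2 PAGEDKV
# = 60 forward kernels, plus 2 bwd dtypes x 5 head dims = 10 backward
# kernels, i.e. 70 generated .cu files in total.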


def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
    output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    for kernel in get_all_kernels():
        write_kernel(kernel, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernels template instantiations",
    )
    # Set an optional output directory
    parser.add_argument(
        "-o",
        "--output_dir",
        default="instantiations",
        required=False,
        help="Where to generate the kernels; defaults to the 'instantiations' directory",
    )
    args = parser.parse_args()
    main(args.output_dir)
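
# Example invocation (illustrative):
#   python generate_kernels.py -o instantiations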