# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
# This file is run to generate the kernel instantiations for the flash_attn kernels
# They are written to several files in order to speed up compilation

import argparse
import itertools
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

KERNEL_BATCH = namedtuple("Kernel", ["template", "filename"])

DTYPE_MAP = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
    "e4m3": "cutlass::float_e4m3_t",
}
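
# fp8 (e4m3) appears only in the full forward map above; SM80 forward kernels
# and all backward kernels are restricted to fp16/bf16 (see the filtering in
# get_all_kernels below).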
DTYPE_MAP_FWD_SM80 = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

DTYPE_MAP_BWD = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

SM = [80, 90]  # SM architectures to generate kernels for
HEAD_DIMENSIONS = [64, 96, 128, 192, 256]
PAGEDKV = [False, True]
SPLIT = [False, True]
SOFTCAP = [False, True]
PACKGQA = [False, True]
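
# Forward kernels are instantiated over the full cross product of dtype, head
# dimension, and the four boolean features (paged KV, split-KV, softcap,
# packed GQA) for each SM architecture; each boolean option doubles the number
# of generated translation units.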

KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);
#endif
"""

KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}
template<>
void run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}>(params, stream);
}}
#endif
"""
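
# For illustration, one rendered forward instantiation (e.g. ARCH=90,
# DTYPE=bf16, HEAD_DIM=128, paged KV enabled, all other features off) is:
#   template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
# The backward template varies only by architecture, dtype, and head dimension.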

@dataclass
class Kernel:
    sm: int
    dtype: str
    head_dim: int
    split: bool
    paged_kv: bool
    softcap: bool
    packgqa: bool
    direction: str

    @property
    def template(self) -> str:
        if self.direction == "fwd":
            return KERNEL_IMPL_TEMPLATE_FWD.format(
                ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
                SOFTCAP=str(self.softcap).lower(), PACKGQA=str(self.packgqa).lower()
            )
        elif self.direction == "bwd":
            return KERNEL_IMPL_TEMPLATE_BWD.format(
                ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )

    @property
    def filename(self) -> str:
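        # e.g. "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" for a forward
        # SM90 bf16 kernel with paged KV and softcap enabled.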
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu"


def get_all_kernels() -> List[Kernel]:
    for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
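        # e4m3 forward kernels are only generated for SM90; SM80 is limited to
        # the dtypes in DTYPE_MAP_FWD_SM80 (fp16/bf16).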
        if sm >= 90 or dtype in DTYPE_MAP_FWD_SM80:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=False, packgqa=False, direction="bwd")
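

# The "hdimall" batch files below add no new instantiations: each one simply
# #includes the per-head-dim .cu files for a single dtype/feature/SM
# combination, giving the build a coarser compilation unit to choose from.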
def batch_hdim(kernels_all) -> List[KERNEL_BATCH]:
    for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)


def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
    output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    kernels_all = list(get_all_kernels())
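    # Emit one .cu file per kernel instantiation, then the hdimall batch files
    # that aggregate them.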
    for kernel in kernels_all:
        write_kernel(kernel, output_dir)
    for kernel in batch_hdim(kernels_all):
        write_kernel(kernel, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernel template instantiations",
    )
    # Set an optional output directory
    parser.add_argument(
        "-o",
        "--output_dir",
        default="instantiations",
        required=False,
        help="Directory to write the generated kernel instantiation files to "
        "(defaults to 'instantiations')",
    )
    args = parser.parse_args()
    main(args.output_dir)