1
0

generate.py 18 KB


  1. import itertools
  2. import math
  3. import os
  4. import shutil
  5. from collections.abc import Iterable
  6. from dataclasses import dataclass
  7. from typing import List, Optional, Tuple, Union
  8. import jinja2
  9. # yapf conflicts with isort for this block
  10. # yapf: disable
  11. from aphrodite_cutlass_library_extension import (APHRODITEDataType,
  12. APHRODITEDataTypeNames,
  13. APHRODITEDataTypeTag,
  14. APHRODITEKernelScheduleTag,
  15. DataType, EpilogueScheduleTag,
  16. EpilogueScheduleType,
  17. MixedInputKernelScheduleType,
  18. TileSchedulerTag,
  19. TileSchedulerType)
  20. # yapf: enable
  21. #
  22. # Generator templating
  23. #
# Jinja2 template for the per-type GEMM dispatcher, written by generate() as
# machete_mm_<terse type name>.cu.
# Render variables (see create_sources): type_name, type_config (TypeConfig),
# schedules (List[ScheduleConfig]) and heuristic
# (List[Tuple[Optional[str], ScheduleConfig]]). DataTypeTag and gen_sch_name
# are supplied as globals by create_template.
# The generated dispatch():
#   * when no schedule is requested, walks the heuristic in order -- each
#     condition is a C++ expression over M, N, K, the first true one wins,
#     and a None condition renders as the trailing `else`;
#   * otherwise looks the requested schedule up by name and raises
#     "not implemented" for unknown names.
# NOTE(review): if a heuristic does not end with a None condition and no
# condition matches, control falls through to `*args.schedule` on an empty
# optional. The default heuristic in generate() ends with None.
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
using GemmDispatcher_ = GemmDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
if (!args.schedule) {
{%- for cond, s in heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
}
{% for s in schedules %}
if (*args.schedule == "{{ gen_sch_name(s) }}") {
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
}
{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.schedule);
}
template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
return {
{% for s in schedules -%}
"{{ gen_sch_name(s) }}"{{ ",
" if not loop.last }}{%- endfor %}
};
}
}; // namespace machete
"""
  67. IMPL_TEMPLATE = """
  68. #include "../machete_mm_launcher.cuh"
  69. namespace machete {
  70. template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
  71. using Kernel = MacheteKernelTemplate<
  72. {{DataTypeTag[type_config.element_a]}}, // ElementA
  73. {{DataTypeTag[type_config.element_b]}}, // ElementB
  74. {{DataTypeTag[type_config.element_d]}}, // ElementD
  75. {{DataTypeTag[type_config.accumulator]}}, // Accumulator
  76. {{DataTypeTag[type_config.element_b_scale]}}, // Scales
  77. {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
  78. cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
  79. Config, with_C, with_scales, with_zeropoints>;
  80. {% for sch in schedules %}
  81. {% set schedule_name = gen_sch_name(sch) -%}
  82. struct sch_{{schedule_name}} {
  83. using TileShapeNM = Shape<{{
  84. to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
  85. using ClusterShape = Shape<{{
  86. to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
  87. // TODO: Reimplement
  88. // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
  89. using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
  90. using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
  91. using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  92. };
  93. torch::Tensor
  94. impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
  95. bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
  96. with_zeropoints = args.zeros.has_value();
  97. {% for s in specializations %}
  98. if (with_C == {{s.with_C|lower}}
  99. && with_zeropoints == {{s.with_zeropoints|lower}}
  100. && with_scales == {{s.with_scales|lower}}) {
  101. return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
  102. {{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
  103. }{% endfor %}
  104. TORCH_CHECK_NOT_IMPLEMENTED(
  105. false, "for the sake of compile times and binary size machete_mm(..) is "
  106. " not implemented for with_C=", with_C, ", with_scales=", with_scales,
  107. ", with_zeropoints=", with_zeropoints,
  108. " (for {{type_name}}_sch_{{schedule_name}})");
  109. }
  110. {% endfor %}
  111. }; // namespace machete
  112. """
# Jinja2 template for the per-type B-prepack dispatcher, written by generate()
# as machete_prepack_<terse type name>.cu.
# Render variables: type_config (TypeConfig). create_sources also passes
# type_name, which this template does not use. DataTypeTag is supplied as a
# global by create_template.
# The prepacked B layout is always built with cutlass::layout::ColumnMajor
# and the KernelTmaWarpSpecializedCooperativeMixedInput schedule; it does not
# depend on any ScheduleConfig.
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
using PrepackedLayoutB = PrepackedLayoutBTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
return prepack_impl<PrepackedLayoutB>(B);
}
}; // namespace machete
"""
  136. TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
  137. TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass(frozen=True)
class ScheduleConfig:
    """Tile/cluster shape and schedule choices for one generated kernel.

    Frozen, so instances are hashable; generate() deduplicates the schedules
    referenced by its heuristic.
    """
    # Tile shape; rendered (via to_cute_constant) into the kernel's
    # ``TileShapeNM`` and used as the first part of the schedule name.
    # Presumably (M, N) order per the field name -- TODO confirm, since the
    # C++ alias is spelled ``TileShapeNM``.
    tile_shape_mn: Tuple[int, int]
    # Cluster shape; rendered (via to_cute_constant) into ``ClusterShape``.
    cluster_shape_mnk: Tuple[int, int, int]
    # Contributes to the schedule name only: IMPL_TEMPLATE hard-codes the
    # kernel schedule and emits this value only in a commented-out line.
    kernel_schedule: MixedInputKernelScheduleType
    # Rendered into the kernel's ``EpilogueSchedule``.
    epilogue_schedule: EpilogueScheduleType
    # Rendered into the kernel's ``TileScheduler``.
    tile_scheduler: TileSchedulerType
@dataclass
class TypeConfig:
    """Element types for one family of generated machete kernels.

    Each field is rendered through ``DataTypeTag`` into the template
    arguments of the generated code, and generate_type_signature()
    concatenates their short names into the kernel type name.
    """
    # ``ElementA`` template argument.
    element_a: DataType
    # ``ElementB`` template argument; may be one of the extension's custom
    # types (e.g. APHRODITEDataType.u4b8).
    element_b: Union[DataType, APHRODITEDataType]
    # Type of the B scales (``Scales`` template argument).
    element_b_scale: DataType
    # Type of the B zero-points (``Zeropoints`` template argument).
    element_b_zeropoint: DataType
    # ``ElementD`` template argument (presumably the output type).
    element_d: DataType
    # ``Accumulator`` template argument.
    accumulator: DataType
@dataclass
class Specialization:
    """One combination of optional GEMM operands that gets a compiled kernel.

    The generated impl_* function compares these flags with the optional
    arguments supplied at runtime, and raises "not implemented" when no
    specialization matches.
    """
    # Whether a C operand is supplied (``args.C`` has a value).
    with_C: bool
    # Whether zero-points are supplied (``args.zeros`` has a value).
    with_zeropoints: bool
    # Whether scales are supplied (``args.scales`` has a value).
    with_scales: bool
@dataclass
class ImplConfig:
    """Everything create_sources() needs to render the files for one type."""
    # Element types shared by every kernel generated from this config.
    type_config: TypeConfig
    # Schedules to instantiate; each can also be requested by name at runtime.
    schedule_configs: List[ScheduleConfig]
    # Optional-operand combinations compiled for every schedule.
    specializations: List[Specialization]
    # Ordered (condition, schedule) pairs used when no schedule is requested.
    # A condition is a C++ expression over M, N, K pasted into the
    # dispatcher; the first match wins and a None condition renders as the
    # final ``else``. The referenced schedules presumably must also appear in
    # ``schedule_configs``, since only those get impl_* declarations.
    heuristic: List[Tuple[Optional[str], ScheduleConfig]]
  164. def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
  165. tile_shape = (
  166. f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
  167. )
  168. cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
  169. f"x{schedule_config.cluster_shape_mnk[1]}" +
  170. f"x{schedule_config.cluster_shape_mnk[2]}")
  171. kernel_schedule = APHRODITEKernelScheduleTag[
  172. schedule_config.kernel_schedule]\
  173. .split("::")[-1]
  174. epilogue_schedule = EpilogueScheduleTag[
  175. schedule_config.epilogue_schedule].split("::")[-1]
  176. tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
  177. .split("::")[-1]
  178. return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
  179. f"_{epilogue_schedule}_{tile_scheduler}")
  180. # mostly unique shorter schedule_name
  181. def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
  182. kernel_terse_names_replace = {
  183. "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
  184. "TmaWarpSpecializedCooperative_": "TmaCoop_",
  185. "StreamKScheduler": "streamK",
  186. }
  187. schedule_name = generate_schedule_name(schedule_config)
  188. for orig, terse in kernel_terse_names_replace.items():
  189. schedule_name = schedule_name.replace(orig, terse)
  190. return schedule_name
  191. # unique type_name
  192. def generate_type_signature(kernel_type_config: TypeConfig):
  193. element_a = APHRODITEDataTypeNames[kernel_type_config.element_a]
  194. element_b = APHRODITEDataTypeNames[kernel_type_config.element_b]
  195. element_d = APHRODITEDataTypeNames[kernel_type_config.element_d]
  196. accumulator = APHRODITEDataTypeNames[kernel_type_config.accumulator]
  197. element_scale = APHRODITEDataTypeNames[kernel_type_config.element_b_scale]
  198. element_zeropoint = APHRODITEDataTypeNames[
  199. kernel_type_config.element_b_zeropoint]
  200. return (f"{element_a}{element_b}{element_d}"
  201. f"{accumulator}{element_scale}{element_zeropoint}")
  202. # non-unique shorter type_name
  203. def generate_terse_type_signature(kernel_type_config: TypeConfig):
  204. element_a = APHRODITEDataTypeNames[kernel_type_config.element_a]
  205. element_b = APHRODITEDataTypeNames[kernel_type_config.element_b]
  206. return f"{element_a}{element_b}"
  207. def is_power_of_two(n):
  208. return (n != 0) and (n & (n - 1) == 0)
  209. def to_cute_constant(value: List[int]):
  210. def _to_cute_constant(value: int):
  211. if is_power_of_two(value):
  212. return f"_{value}"
  213. else:
  214. return f"Int<{value}>"
  215. if isinstance(value, Iterable):
  216. return [_to_cute_constant(value) for value in value]
  217. else:
  218. return _to_cute_constant(value)
  219. template_globals = {
  220. "DataTypeTag": APHRODITEDataTypeTag,
  221. "KernelScheduleTag": APHRODITEKernelScheduleTag,
  222. "EpilogueScheduleTag": EpilogueScheduleTag,
  223. "TileSchedulerTag": TileSchedulerTag,
  224. "to_cute_constant": to_cute_constant,
  225. "gen_sch_name": generate_terse_schedule_name,
  226. }
  227. def create_template(template_str):
  228. template = jinja2.Template(template_str)
  229. template.globals.update(template_globals)
  230. return template
  231. mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
  232. mm_impl_template = create_template(IMPL_TEMPLATE)
  233. prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
  234. def create_sources(impl_config: ImplConfig, num_impl_files=2):
  235. sources = []
  236. type_name = generate_type_signature(impl_config.type_config)
  237. terse_type_name = generate_terse_type_signature(impl_config.type_config)
  238. sources.append((
  239. f"machete_mm_{terse_type_name}",
  240. mm_dispatch_template.render(type_name=type_name,
  241. type_config=impl_config.type_config,
  242. schedules=impl_config.schedule_configs,
  243. heuristic=impl_config.heuristic),
  244. ))
  245. sources.append((
  246. f"machete_prepack_{terse_type_name}",
  247. prepack_dispatch_template.render(
  248. type_name=type_name,
  249. type_config=impl_config.type_config,
  250. ),
  251. ))
  252. num_schedules = len(impl_config.schedule_configs)
  253. schedules_per_file = math.ceil(num_schedules / num_impl_files)
  254. for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
  255. file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
  256. sources.append((
  257. f"machete_mm_{terse_type_name}_impl_part{part}",
  258. mm_impl_template.render(
  259. type_name=type_name,
  260. type_config=impl_config.type_config,
  261. schedules=file_schedules,
  262. specializations=impl_config.specializations,
  263. ),
  264. ))
  265. return sources
  266. def generate():
  267. # See csrc/quantization/machete/Readme.md, the Codegeneration for more info
  268. # about how this works
  269. SCRIPT_DIR = os.path.dirname(__file__)
  270. schedule_common_params = dict(
  271. kernel_schedule=TmaMI,
  272. epilogue_schedule=TmaCoop,
  273. tile_scheduler=TileSchedulerType.StreamK,
  274. )
  275. # For now we use the same heuristic for all types
  276. # Heuristic is currently tuned for H100s
  277. default_heuristic = [
  278. #### M = 257+
  279. (
  280. "M > 256 && K <= 16384 && N <= 4096",
  281. ScheduleConfig(
  282. tile_shape_mn=(128, 128),
  283. cluster_shape_mnk=(2, 1, 1),
  284. **schedule_common_params # type: ignore
  285. )),
  286. (
  287. "M > 256",
  288. ScheduleConfig(
  289. tile_shape_mn=(128, 256),
  290. cluster_shape_mnk=(2, 1, 1),
  291. **schedule_common_params # type: ignore
  292. )),
  293. #### M = 129-256
  294. (
  295. "M > 128 && K <= 4096 && N <= 4096",
  296. ScheduleConfig(
  297. tile_shape_mn=(128, 64),
  298. cluster_shape_mnk=(2, 1, 1),
  299. **schedule_common_params # type: ignore
  300. )),
  301. (
  302. "M > 128 && K <= 8192 && N <= 8192",
  303. ScheduleConfig(
  304. tile_shape_mn=(128, 128),
  305. cluster_shape_mnk=(2, 1, 1),
  306. **schedule_common_params # type: ignore
  307. )),
  308. (
  309. "M > 128",
  310. ScheduleConfig(
  311. tile_shape_mn=(128, 256),
  312. cluster_shape_mnk=(2, 1, 1),
  313. **schedule_common_params # type: ignore
  314. )),
  315. #### M = 65-128
  316. (
  317. "M > 64 && K <= 4069 && N <= 4069",
  318. ScheduleConfig(
  319. tile_shape_mn=(128, 32),
  320. cluster_shape_mnk=(2, 1, 1),
  321. **schedule_common_params # type: ignore
  322. )),
  323. (
  324. "M > 64 && K <= 4069 && N <= 8192",
  325. ScheduleConfig(
  326. tile_shape_mn=(128, 64),
  327. cluster_shape_mnk=(2, 1, 1),
  328. **schedule_common_params # type: ignore
  329. )),
  330. (
  331. "M > 64 && K >= 8192 && N >= 12288",
  332. ScheduleConfig(
  333. tile_shape_mn=(256, 128),
  334. cluster_shape_mnk=(2, 1, 1),
  335. **schedule_common_params # type: ignore
  336. )),
  337. (
  338. "M > 64",
  339. ScheduleConfig(
  340. tile_shape_mn=(128, 128),
  341. cluster_shape_mnk=(2, 1, 1),
  342. **schedule_common_params # type: ignore
  343. )),
  344. #### M = 33-64
  345. (
  346. "M > 32 && K <= 6144 && N <= 6144",
  347. ScheduleConfig(
  348. tile_shape_mn=(128, 16),
  349. cluster_shape_mnk=(1, 1, 1),
  350. **schedule_common_params # type: ignore
  351. )),
  352. (
  353. "M > 32 && K >= 16384 && N >= 12288",
  354. ScheduleConfig(
  355. tile_shape_mn=(256, 64),
  356. cluster_shape_mnk=(2, 1, 1),
  357. **schedule_common_params # type: ignore
  358. )),
  359. (
  360. "M > 32",
  361. ScheduleConfig(
  362. tile_shape_mn=(128, 64),
  363. cluster_shape_mnk=(2, 1, 1),
  364. **schedule_common_params # type: ignore
  365. )),
  366. #### M = 17-32
  367. (
  368. "M > 16 && K <= 12288 && N <= 8192",
  369. ScheduleConfig(
  370. tile_shape_mn=(128, 32),
  371. cluster_shape_mnk=(2, 1, 1),
  372. **schedule_common_params # type: ignore
  373. )),
  374. (
  375. "M > 16",
  376. ScheduleConfig(
  377. tile_shape_mn=(256, 32),
  378. cluster_shape_mnk=(2, 1, 1),
  379. **schedule_common_params # type: ignore
  380. )),
  381. #### M = 1-16
  382. (
  383. "N >= 26624",
  384. ScheduleConfig(
  385. tile_shape_mn=(256, 16),
  386. cluster_shape_mnk=(1, 1, 1),
  387. **schedule_common_params # type: ignore
  388. )),
  389. (
  390. None,
  391. ScheduleConfig(
  392. tile_shape_mn=(128, 16),
  393. cluster_shape_mnk=(1, 1, 1),
  394. **schedule_common_params # type: ignore
  395. )),
  396. ]
  397. schedules = list(set([x[1] for x in default_heuristic]))
  398. impl_configs = []
  399. GPTQ_kernel_type_configs = list(
  400. (TypeConfig(
  401. element_a=element_a,
  402. element_b=element_b,
  403. element_b_scale=element_a,
  404. element_b_zeropoint=element_a,
  405. element_d=element_a,
  406. accumulator=DataType.f32,
  407. ) for element_b in (APHRODITEDataType.u4b8, APHRODITEDataType.u8b128)
  408. for element_a in (DataType.f16, DataType.bf16)))
  409. GPTQ_kernel_specializations = [
  410. Specialization(with_C=False, with_zeropoints=False, with_scales=True)
  411. ]
  412. impl_configs += [
  413. ImplConfig(x[0], x[1], x[2], x[3])
  414. for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
  415. itertools.repeat(GPTQ_kernel_specializations),
  416. itertools.repeat(default_heuristic))
  417. ]
  418. AWQ_kernel_type_configs = list(
  419. (TypeConfig(
  420. element_a=element_a,
  421. element_b=element_b,
  422. element_b_scale=element_a,
  423. element_b_zeropoint=element_a,
  424. element_d=element_a,
  425. accumulator=DataType.f32,
  426. ) for element_b in (DataType.u4, DataType.u8)
  427. for element_a in (DataType.f16, DataType.bf16)))
  428. AWQ_kernel_specializations = [
  429. Specialization(with_C=False, with_zeropoints=True, with_scales=True)
  430. ]
  431. impl_configs += [
  432. ImplConfig(x[0], x[1], x[2], x[3])
  433. for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
  434. itertools.repeat(AWQ_kernel_specializations),
  435. itertools.repeat(default_heuristic))
  436. ]
  437. output_dir = os.path.join(SCRIPT_DIR, "generated")
  438. # Delete the "generated" directory if it exists
  439. if os.path.exists(output_dir):
  440. shutil.rmtree(output_dir)
  441. # Create the "generated" directory
  442. os.makedirs(output_dir)
  443. # Render each group of configurations into separate files
  444. for impl_config in impl_configs:
  445. for filename, code in create_sources(impl_config):
  446. filepath = os.path.join(output_dir, f"{filename}.cu")
  447. with open(filepath, "w") as output_file:
  448. output_file.write(code)
  449. print(f"Rendered template to {filepath}")
  450. if __name__ == "__main__":
  451. generate()