123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import jinja2
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite_cutlass_library_extension import (APHRODITEDataType,
                                                 APHRODITEDataTypeNames,
                                                 APHRODITEDataTypeTag,
                                                 APHRODITEKernelScheduleTag,
                                                 DataType, EpilogueScheduleTag,
                                                 EpilogueScheduleType,
                                                 MixedInputKernelScheduleType,
                                                 TileSchedulerTag,
                                                 TileSchedulerType)

# yapf: enable
#
# Generator templating
#

# Per-type GEMM dispatcher. Rendered by create_sources() with `type_name`,
# `type_config`, `schedules` and `heuristic`. It declares one impl function
# per schedule and picks one either from the explicitly requested schedule
# name or, when none is given, from the heuristic's chain of C++ conditions
# (first match wins; a `None` condition becomes the final `else`).
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
using GemmDispatcher_ = GemmDispatcher<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
    {{DataTypeTag[type_config.element_b_scale]}}, // Scales
    {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints

{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
  [[maybe_unused]] auto M = args.A.size(0);
  [[maybe_unused]] auto N = args.B.size(1);
  [[maybe_unused]] auto K = args.A.size(1);

  if (!args.schedule) {
    {%- for cond, s in heuristic %}
    {%if cond is not none%}if ({{cond}})
    {%- else %}else
    {%- endif %}
        return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
  }

  {% for s in schedules %}
  if (*args.schedule == "{{ gen_sch_name(s) }}") {
    return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
  }
  {% endfor %}
  TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
                                     "schedule = ", *args.schedule);
}

template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
  return {
    {% for s in schedules -%}
    "{{ gen_sch_name(s) }}"{{ ",
    " if not loop.last }}{%- endfor %}
  };
}

}; // namespace machete
"""
# Per-schedule kernel instantiations. Rendered by create_sources() with
# `type_name`, `type_config`, `schedules` (one slice of the impl config's
# schedules per output file) and `specializations`. For each schedule it
# emits a config struct plus an `impl_*` entry point that runs the kernel
# variant matching the provided optional inputs, or fails with
# TORCH_CHECK_NOT_IMPLEMENTED when no compiled specialization matches.
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = MacheteKernelTemplate<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}}, // Accumulator
    {{DataTypeTag[type_config.element_b_scale]}}, // Scales
    {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
    Config, with_C, with_scales, with_zeropoints>;

{% for sch in schedules %}
{% set schedule_name = gen_sch_name(sch) -%}
struct sch_{{schedule_name}} {
  using TileShapeNM = Shape<{{
      to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
  using ClusterShape = Shape<{{
      to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
  // TODO: Reimplement
  // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
  using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
  using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};

torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
  bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
       with_zeropoints = args.zeros.has_value();

  {% for s in specializations %}
  if (with_C == {{s.with_C|lower}}
      && with_zeropoints == {{s.with_zeropoints|lower}}
      && with_scales == {{s.with_scales|lower}}) {
    return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
        {{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
  }{% endfor %}

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "for the sake of compile times and binary size machete_mm(..) is "
      " not implemented for with_C=", with_C, ", with_scales=", with_scales,
      ", with_zeropoints=", with_zeropoints,
      " (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}

}; // namespace machete
"""
# Per-type B-operand prepack dispatcher. Rendered by create_sources() with
# `type_name` and `type_config`; it forwards to prepack_impl() using a
# column-major PrepackedLayoutB for the cooperative mixed-input kernel.
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"

namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
  {{DataTypeTag[type_config.element_a]}}, // ElementA
  {{DataTypeTag[type_config.element_b]}}, // ElementB
  {{DataTypeTag[type_config.element_d]}}, // ElementD
  {{DataTypeTag[type_config.accumulator]}}, // Accumulator
  {{DataTypeTag[type_config.element_b_scale]}}, // Scales
  {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints

using PrepackedLayoutB = PrepackedLayoutBTemplate<
  {{DataTypeTag[type_config.element_a]}}, // ElementA
  {{DataTypeTag[type_config.element_b]}}, // ElementB
  {{DataTypeTag[type_config.element_d]}}, // ElementD
  {{DataTypeTag[type_config.accumulator]}}, // Accumulator
  cutlass::layout::ColumnMajor,
  cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;

template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
  return prepack_impl<PrepackedLayoutB>(B);
}
}; // namespace machete
"""
# Short aliases for the kernel / epilogue schedules used by the schedule
# configurations built in generate().
TmaMI = (MixedInputKernelScheduleType
         .TmaWarpSpecializedCooperativeMixedInput)
TmaCoop = (EpilogueScheduleType
           .TmaWarpSpecializedCooperative)
@dataclass(eq=True, frozen=True)
class ScheduleConfig:
    """One kernel schedule: tile shape, cluster shape and CUTLASS schedules.

    Frozen, and therefore hashable, so identical schedules referenced by
    several heuristic entries can be deduplicated.
    """

    # (M, N) tile extents.
    tile_shape_mn: Tuple[int, int]
    # (M, N, K) cluster extents.
    cluster_shape_mnk: Tuple[int, int, int]
    kernel_schedule: MixedInputKernelScheduleType
    epilogue_schedule: EpilogueScheduleType
    tile_scheduler: TileSchedulerType
@dataclass()
class TypeConfig:
    """Element types of one GEMM instantiation (A x B -> D)."""

    # A-operand element type.
    element_a: DataType
    # B-operand element type; may be an APHRODITE-specific (non-CUTLASS) type.
    element_b: Union[DataType, APHRODITEDataType]
    # Element type of the B-operand scales.
    element_b_scale: DataType
    # Element type of the B-operand zero-points.
    element_b_zeropoint: DataType
    # Output element type.
    element_d: DataType
    # Accumulator type.
    accumulator: DataType
@dataclass()
class Specialization:
    """Which optional GEMM inputs a compiled kernel variant is built for.

    The impl template emits one kernel per Specialization and rejects at
    runtime any combination of inputs that has no matching Specialization.
    """

    # Whether the C operand is provided.
    with_C: bool
    # Whether B zero-points are provided.
    with_zeropoints: bool
    # Whether B scales are provided.
    with_scales: bool
@dataclass()
class ImplConfig:
    """Everything needed to render the generated sources for one TypeConfig."""

    type_config: TypeConfig
    # Schedules to instantiate kernels for.
    schedule_configs: List[ScheduleConfig]
    # Optional-input combinations to instantiate for every schedule.
    specializations: List[Specialization]
    # Ordered (C++ condition over M/N/K, schedule) pairs used when no schedule
    # is requested explicitly; a None condition is emitted as the final
    # unconditional `else` branch.
    heuristic: List[Tuple[Optional[str], ScheduleConfig]]
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
    """Return the full, unique schedule name for ``schedule_config``.

    Format: ``{M}x{N}_{cM}x{cN}x{cK}_{kernel}_{epilogue}_{scheduler}``.
    The last three components are the final ``::``-separated segment of
    the corresponding C++ tag strings.
    """

    def _unqualified(cpp_tag: str) -> str:
        # Strip namespace qualification, keeping only the trailing segment.
        return cpp_tag.split("::")[-1]

    tile = schedule_config.tile_shape_mn
    cluster = schedule_config.cluster_shape_mnk
    components = [
        f"{tile[0]}x{tile[1]}",
        f"{cluster[0]}x{cluster[1]}x{cluster[2]}",
        _unqualified(
            APHRODITEKernelScheduleTag[schedule_config.kernel_schedule]),
        _unqualified(EpilogueScheduleTag[schedule_config.epilogue_schedule]),
        _unqualified(TileSchedulerTag[schedule_config.tile_scheduler]),
    ]
    return "_".join(components)
# mostly unique shorter schedule_name
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
    """Return an abbreviated form of generate_schedule_name().

    Long CUTLASS identifiers are replaced by short aliases, applied in the
    order listed below. The result is only "mostly" unique.
    """
    abbreviations = (
        ("KernelTmaWarpSpecializedCooperativeMixedInput_", "TmaMI_"),
        ("TmaWarpSpecializedCooperative_", "TmaCoop_"),
        ("StreamKScheduler", "streamK"),
    )
    name = generate_schedule_name(schedule_config)
    for long_form, short_form in abbreviations:
        name = name.replace(long_form, short_form)
    return name
# unique type_name
def generate_type_signature(kernel_type_config: TypeConfig):
    """Return a unique type name concatenating all six element-type names.

    Order: A, B, D, accumulator, B scale, B zero-point.
    """
    ordered_types = (
        kernel_type_config.element_a,
        kernel_type_config.element_b,
        kernel_type_config.element_d,
        kernel_type_config.accumulator,
        kernel_type_config.element_b_scale,
        kernel_type_config.element_b_zeropoint,
    )
    return "".join(APHRODITEDataTypeNames[dtype] for dtype in ordered_types)
# non-unique shorter type_name
def generate_terse_type_signature(kernel_type_config: TypeConfig):
    """Return a short type name built only from the A and B element types.

    Not unique: configs that differ only in D, accumulator, scale or
    zero-point types map to the same name.
    """
    leading_types = (kernel_type_config.element_a,
                     kernel_type_config.element_b)
    return "".join(APHRODITEDataTypeNames[dtype] for dtype in leading_types)
def is_power_of_two(n):
    """Return True if integer ``n`` is a positive power of two (1, 2, 4, ...).

    Zero and negative integers yield False.
    """
    if n == 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero.
    return (n & (n - 1)) == 0
def to_cute_constant(value: List[int]):
    """Spell an int, or each int of an iterable, as a CuTe constant string.

    Powers of two are emitted as ``_N``; any other value as ``Int<N>``.
    An iterable argument yields a list of strings, and a scalar yields a
    single string.
    """

    def _spell(number: int) -> str:
        return f"_{number}" if is_power_of_two(number) else f"Int<{number}>"

    if not isinstance(value, Iterable):
        return _spell(value)
    return [_spell(item) for item in value]
# Names exposed as globals to every Jinja template built by create_template().
# "DataTypeTag" is used by the templates to spell element types in the
# emitted C++. "gen_sch_name" is the terse schedule namer, so rendered symbol
# names agree with the names used on the Python side.
# APHRODITEDataTypeTag must be imported from aphrodite_cutlass_library_extension
# alongside the other tag tables.
template_globals = {
    "DataTypeTag": APHRODITEDataTypeTag,
    "KernelScheduleTag": APHRODITEKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    "gen_sch_name": generate_terse_schedule_name,
}
def create_template(template_str):
    """Compile ``template_str`` into a Jinja2 template.

    Every entry of ``template_globals`` is made available to the template
    as a global name.
    """
    compiled = jinja2.Template(template_str)
    for global_name, global_value in template_globals.items():
        compiled.globals[global_name] = global_value
    return compiled
# Compile each C++ source template once, at import time.
mm_dispatch_template, mm_impl_template, prepack_dispatch_template = (
    create_template(source)
    for source in (DISPATCH_TEMPLATE, IMPL_TEMPLATE, PREPACK_TEMPLATE))
def create_sources(impl_config: ImplConfig, num_impl_files=2):
    """Render every CUDA source needed for one ImplConfig.

    Args:
        impl_config: type configuration, schedules, specializations and
            dispatch heuristic to render.
        num_impl_files: number of files across which the per-schedule kernel
            implementations are split. Must be >= 1.

    Returns:
        A list of ``(filename_stem, source_code)`` tuples in this order: the
        GEMM dispatcher, the prepack dispatcher, then the ``_impl_partN``
        files. With no schedules, no impl files are produced.

    Raises:
        ValueError: if ``num_impl_files`` is less than 1.
    """
    if num_impl_files < 1:
        raise ValueError(
            f"num_impl_files must be >= 1, got {num_impl_files}")

    type_config = impl_config.type_config
    type_name = generate_type_signature(type_config)
    terse_type_name = generate_terse_type_signature(type_config)

    sources = [
        (
            f"machete_mm_{terse_type_name}",
            mm_dispatch_template.render(
                type_name=type_name,
                type_config=type_config,
                schedules=impl_config.schedule_configs,
                heuristic=impl_config.heuristic),
        ),
        (
            f"machete_prepack_{terse_type_name}",
            prepack_dispatch_template.render(
                type_name=type_name,
                type_config=type_config,
            ),
        ),
    ]

    schedules = impl_config.schedule_configs
    # max(1, ...) keeps the range() step non-zero when there are no
    # schedules (previously a cryptic "range() arg 3 must not be zero").
    schedules_per_file = max(1, math.ceil(len(schedules) / num_impl_files))
    for part, start in enumerate(range(0, len(schedules),
                                       schedules_per_file)):
        sources.append((
            f"machete_mm_{terse_type_name}_impl_part{part}",
            mm_impl_template.render(
                type_name=type_name,
                type_config=type_config,
                schedules=schedules[start:start + schedules_per_file],
                specializations=impl_config.specializations,
            ),
        ))
    return sources
def _make_schedule(tile_shape_mn, cluster_shape_mnk):
    """Build a ScheduleConfig with the kernel schedule, epilogue schedule and
    tile scheduler shared by every generated schedule."""
    return ScheduleConfig(
        tile_shape_mn=tile_shape_mn,
        cluster_shape_mnk=cluster_shape_mnk,
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )


def _default_heuristic():
    """Return the ordered ``(condition, schedule)`` dispatch list.

    Each condition is a C++ boolean expression over M, N and K that is
    emitted verbatim into the dispatcher. The first matching entry wins, and
    the trailing ``None`` entry becomes the unconditional ``else`` branch.
    """
    # For now we use the same heuristic for all types.
    # Heuristic is currently tuned for H100s.
    # NOTE(review): "4069" in the M = 65-128 conditions looks like a typo for
    # 4096. It is kept as-is because changing it alters tuned dispatch;
    # confirm with the tuning data.
    return [
        #### M = 257+
        ("M > 256 && K <= 16384 && N <= 4096",
         _make_schedule((128, 128), (2, 1, 1))),
        ("M > 256",
         _make_schedule((128, 256), (2, 1, 1))),
        #### M = 129-256
        ("M > 128 && K <= 4096 && N <= 4096",
         _make_schedule((128, 64), (2, 1, 1))),
        ("M > 128 && K <= 8192 && N <= 8192",
         _make_schedule((128, 128), (2, 1, 1))),
        ("M > 128",
         _make_schedule((128, 256), (2, 1, 1))),
        #### M = 65-128
        ("M > 64 && K <= 4069 && N <= 4069",
         _make_schedule((128, 32), (2, 1, 1))),
        ("M > 64 && K <= 4069 && N <= 8192",
         _make_schedule((128, 64), (2, 1, 1))),
        ("M > 64 && K >= 8192 && N >= 12288",
         _make_schedule((256, 128), (2, 1, 1))),
        ("M > 64",
         _make_schedule((128, 128), (2, 1, 1))),
        #### M = 33-64
        ("M > 32 && K <= 6144 && N <= 6144",
         _make_schedule((128, 16), (1, 1, 1))),
        ("M > 32 && K >= 16384 && N >= 12288",
         _make_schedule((256, 64), (2, 1, 1))),
        ("M > 32",
         _make_schedule((128, 64), (2, 1, 1))),
        #### M = 17-32
        ("M > 16 && K <= 12288 && N <= 8192",
         _make_schedule((128, 32), (2, 1, 1))),
        ("M > 16",
         _make_schedule((256, 32), (2, 1, 1))),
        #### M = 1-16
        ("N >= 26624",
         _make_schedule((256, 16), (1, 1, 1))),
        (None,
         _make_schedule((128, 16), (1, 1, 1))),
    ]


def _make_type_configs(element_b_types):
    """Return one TypeConfig per (element_b, element_a) pair.

    element_a ranges over f16 and bf16. Scales, zero-points and the output
    use element_a's type, and accumulation is f32. Ordering is
    element_b-major.
    """
    return [
        TypeConfig(
            element_a=element_a,
            element_b=element_b,
            element_b_scale=element_a,
            element_b_zeropoint=element_a,
            element_d=element_a,
            accumulator=DataType.f32,
        ) for element_b, element_a in itertools.product(
            element_b_types, (DataType.f16, DataType.bf16))
    ]


def _write_sources(output_dir, impl_configs):
    """Recreate ``output_dir`` from scratch and write every rendered .cu file."""
    # Delete the directory first so files from a previous run (e.g. for
    # schedules that no longer exist) do not linger.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    for impl_config in impl_configs:
        for filename, code in create_sources(impl_config):
            filepath = os.path.join(output_dir, f"{filename}.cu")
            with open(filepath, "w", encoding="utf-8") as output_file:
                output_file.write(code)
            print(f"Rendered template to {filepath}")


def generate():
    """Render all machete GEMM and prepack sources into ``<script dir>/generated``.

    See csrc/quantization/machete/Readme.md (the Codegeneration section) for
    more info about how this works.
    """
    script_dir = os.path.dirname(__file__)

    default_heuristic = _default_heuristic()
    # Deduplicate while keeping first-appearance order. A set() made the
    # order depend on hashing (enum members hash by name, which is subject to
    # string-hash randomization). The schedule order, and therefore the
    # generated file contents and the impl_part split, could then vary
    # between runs.
    schedules = list(dict.fromkeys(sched for _, sched in default_heuristic))

    gptq_specializations = [
        Specialization(with_C=False, with_zeropoints=False, with_scales=True)
    ]
    awq_specializations = [
        Specialization(with_C=False, with_zeropoints=True, with_scales=True)
    ]
    impl_configs = [
        ImplConfig(type_config, schedules, specializations, default_heuristic)
        for element_b_types, specializations in (
            # GPTQ-style weight types: scales only.
            ((APHRODITEDataType.u4b8, APHRODITEDataType.u8b128),
             gptq_specializations),
            # AWQ-style weight types: scales and zero-points.
            ((DataType.u4, DataType.u8), awq_specializations),
        )
        for type_config in _make_type_configs(element_b_types)
    ]

    _write_sources(os.path.join(script_dir, "generated"), impl_configs)


if __name__ == "__main__":
    generate()