12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- import enum
- from typing import Dict, Union
- from cutlass_library import *
- #
- # Extend cutlass library with custom types, and missing values
- #
class APHRODITEDataType(enum.Enum):
    """Custom data types that extend cutlass_library's ``DataType`` set.

    Members are used as extra keys in ``APHRODITEDataTypeNames`` and
    ``APHRODITEDataTypeTag`` alongside the stock ``DataType`` members.
    """

    # NOTE(review): presumably a 4-bit unsigned type with a bias of 8
    # (C++ tag ``cutlass::aphrodite_uint4b8_t``) — confirm.
    u4b8 = enum.auto()
    # NOTE(review): presumably an 8-bit unsigned type with a bias of 128
    # (C++ tag ``cutlass::aphrodite_uint8b128_t``) — confirm.
    u8b128 = enum.auto()
class MixedInputKernelScheduleType(enum.Enum):
    """Mixed-input GEMM kernel schedules missing from cutlass_library.

    Each member is mapped to its C++ schedule tag in
    ``APHRODITEKernelScheduleTag``, next to the stock
    ``KernelScheduleType`` entries.
    """

    TmaWarpSpecializedMixedInput = enum.auto()
    TmaWarpSpecializedPingpongMixedInput = enum.auto()
    TmaWarpSpecializedCooperativeMixedInput = enum.auto()
# Printable names for every data type: all of cutlass_library's DataTypeNames
# entries, followed by each custom APHRODITEDataType member under its own
# enum member name (e.g. APHRODITEDataType.u4b8 -> "u4b8").
APHRODITEDataTypeNames: Dict[Union[APHRODITEDataType, DataType], str] = {
    **DataTypeNames,  # type: ignore
    **{custom: custom.name for custom in APHRODITEDataType},
}
# C++ type tag for every data type: all of cutlass_library's DataTypeTag
# entries, plus the C++ spellings of the custom APHRODITEDataType members.
APHRODITEDataTypeTag: Dict[Union[APHRODITEDataType, DataType], str] = {
    **DataTypeTag,  # type: ignore
    APHRODITEDataType.u4b8: "cutlass::aphrodite_uint4b8_t",
    APHRODITEDataType.u8b128: "cutlass::aphrodite_uint8b128_t",
}
# C++ kernel-schedule tag for every schedule type: all of cutlass_library's
# KernelScheduleTag entries, plus the mixed-input schedules defined above.
APHRODITEKernelScheduleTag: Dict[
    Union[MixedInputKernelScheduleType, KernelScheduleType], str
] = {
    **KernelScheduleTag,  # type: ignore
    MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
        "cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
    MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
        "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
    MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
        "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
}
|