aphrodite_cutlass_library_extension.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import enum
  2. from typing import Dict, Union
  3. from cutlass_library import *
#
# Extend the CUTLASS library (cutlass_library) with custom types and the values it is missing.
#
  7. class APHRODITEDataType(enum.Enum):
  8. u4b8 = enum_auto()
  9. u8b128 = enum_auto()
  10. class MixedInputKernelScheduleType(enum.Enum):
  11. TmaWarpSpecializedMixedInput = enum_auto()
  12. TmaWarpSpecializedPingpongMixedInput = enum_auto()
  13. TmaWarpSpecializedCooperativeMixedInput = enum_auto()
  14. APHRODITEDataTypeNames: Dict[Union[APHRODITEDataType, DataType], str] = {
  15. **DataTypeNames, # type: ignore
  16. **{
  17. APHRODITEDataType.u4b8: "u4b8",
  18. APHRODITEDataType.u8b128: "u8b128",
  19. }
  20. }
  21. APHRODITEDataTypeTag: Dict[Union[APHRODITEDataType, DataType], str] = {
  22. **DataTypeTag, # type: ignore
  23. **{
  24. APHRODITEDataType.u4b8: "cutlass::aphrodite_uint4b8_t",
  25. APHRODITEDataType.u8b128: "cutlass::aphrodite_uint8b128_t",
  26. }
  27. }
  28. APHRODITEKernelScheduleTag: Dict[Union[
  29. MixedInputKernelScheduleType, KernelScheduleType], str] = {
  30. **KernelScheduleTag, # type: ignore
  31. **{
  32. MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
  33. "cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
  34. MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
  35. "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
  36. MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
  37. "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
  38. }
  39. }