machete_collective_builder.cuh 1.4 KB

123456789101112131415161718192021222324252627282930313233
  1. #pragma once
  2. #include "cutlass_extensions/aphrodite_collective_builder.cuh"
  3. #include "machete_mainloop.cuh"
  4. namespace cutlass::gemm::collective {
  5. using namespace cute;
  6. struct MacheteKernelTag {};
  7. template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
  8. class ElementPairB_, class GmemLayoutB_, int AlignmentB,
  9. class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
  10. class StageCountType, class KernelScheduleType>
  11. struct APHRODITECollectiveBuilder<
  12. MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
  13. GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
  14. ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
  15. KernelScheduleType,
  16. cute::enable_if_t<(
  17. cute::is_same_v<KernelScheduleType,
  18. KernelTmaWarpSpecializedMixedInput> ||
  19. cute::is_same_v<KernelScheduleType,
  20. KernelTmaWarpSpecializedPingpongMixedInput> ||
  21. cute::is_same_v<KernelScheduleType,
  22. KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
  23. using CollectiveOp = machete::MacheteCollectiveMma<
  24. ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
  25. AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
  26. StageCountType, KernelScheduleType>;
  27. };
  28. }; // namespace cutlass::gemm::collective