named_barrier.hpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/arch/barrier.h"
  6. namespace flash {
  7. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // cutlass::arch::NamedBarrier::sync/arrive are only enabled on Sm90, even though they work
  // on Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80.
  10. CUTLASS_DEVICE
  11. static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) {
  12. static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
  13. uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
  14. asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  15. cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
  16. }
  17. CUTLASS_DEVICE
  18. static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
  19. uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
  20. asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  21. cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
  22. }
  23. CUTLASS_DEVICE
  24. static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) {
  25. static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
  26. uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
  27. cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
  28. asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  29. }
  30. CUTLASS_DEVICE
  31. static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
  32. uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
  33. cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
  34. asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  35. }
  36. ////////////////////////////////////////////////////////////////////////////////////////////////////
  37. // Enumerates the reserved named barriers to avoid potential conflicts
  // Named-barrier indices for the forward pass.
  // NOTE(review): presumably these are cast to uint32_t and passed to the uint32_t overloads of
  // named_barrier_sync / named_barrier_arrive, which add the FirstUserBarrier offset — confirm
  // against the callers.
  enum class FwdNamedBarriers {
      QueryEmpty = 0,
      ProducerWG = 1,
      TileCountSmemEmpty = 2,
      TileCountSmemFull = 3,
      WarpSchedulerWG1 = 4,
      WarpSchedulerWG2 = 5,
      WarpSchedulerWG3 = 6,
      AppendKV = 7,
      QueryRotated = 8,
      PFull = 9,
      // HACK: PEmpty deliberately shares WarpSchedulerWG3's index (6). PEmpty is only used when
      // we don't have 3 WGs, so the two never need the index at the same time.
      PEmpty = WarpSchedulerWG3,
  };

  // Named-barrier indices for the backward pass.
  enum class BwdNamedBarriers {
      KVEmpty = 0,
      PdS = 1,
      // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it
      // (enforced by the static_assert below).
      TileCountSmemEmpty = 2,
      TileCountSmemFull = 3,
      dQEmptyWG1 = 4,
      dQEmptyWG2 = 5,
      dQEmptyWG3 = 6,
      dQFullWG1 = 7,
      dQFullWG2 = 8,
      dQFullWG3 = 9,
  };

  // Turn the documented cross-enum requirement into a compile-time check so that renumbering
  // either enum cannot silently break it.
  // NOTE(review): TileCountSmemFull presumably carries the same requirement — confirm against
  // TileScheduler before adding a matching assertion.
  static_assert(static_cast<int>(BwdNamedBarriers::TileCountSmemEmpty) ==
                    static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty),
                "BwdNamedBarriers::TileCountSmemEmpty must match FwdNamedBarriers::TileCountSmemEmpty");
  64. } // flash