named_barrier.hpp 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/arch/barrier.h"
  6. namespace flash {
  7. ////////////////////////////////////////////////////////////////////////////////////////////////////
// cutlass::arch::NamedBarrier::sync/arrive are only enabled for Sm90, even though the
// underlying bar.sync / bar.arrive instructions work on Sm80 as well. We reimplement them
// here so they are available on both Sm90 and Sm80.
  10. CUTLASS_DEVICE
  11. static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) {
  12. static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
  13. uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
  14. asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  15. cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
  16. }
  17. CUTLASS_DEVICE
  18. static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
  19. uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
  20. asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  21. cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
  22. }
  23. CUTLASS_DEVICE
  24. static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) {
  25. static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
  26. uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
  27. cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
  28. asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  29. }
  30. CUTLASS_DEVICE
  31. static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
  32. uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
  33. cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
  34. asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
  35. }
  36. ////////////////////////////////////////////////////////////////////////////////////////////////////
  37. // Enumerates the reserved named barriers to avoid potential conflicts
  38. enum class FwdNamedBarriers {
  39. QueryEmpty = 0,
  40. ProducerWG = 1,
  41. TileCountSmemEmpty = 2,
  42. TileCountSmemFull = 3,
  43. WarpSchedulerWG1 = 4,
  44. WarpSchedulerWG2 = 5,
  45. WarpSchedulerWG3 = 6,
  46. AppendKV = 7,
  47. QueryRotated = 8,
  48. };
  49. enum class BwdNamedBarriers {
  50. KVEmpty = 0,
  51. PdS = 1,
  52. // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it
  53. TileCountSmemEmpty = 2,
  54. TileCountSmemFull = 3,
  55. dQEmptyWG1 = 4,
  56. dQEmptyWG2 = 5,
  57. dQEmptyWG3 = 6,
  58. dQFullWG1 = 7,
  59. dQFullWG2 = 8,
  60. dQFullWG3 = 9,
  61. };
  62. } // flash