tile_scheduler_bwd.hpp 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cutlass/fast_math.h"
  6. #include "cutlass/arch/barrier.h"
  7. #include "named_barrier.hpp"
  8. namespace flash {
  9. ///////////////////////////////////////////////////////////////////////////////
  10. class SingleTileSchedulerBwd {
  11. public:
  12. using SharedStorage = int;
  13. // Host side kernel arguments
  14. struct Arguments {
  15. int const num_blocks_m, num_head, num_batch;
  16. int* const tile_count_semaphore = nullptr;
  17. int* const cu_seqlens = nullptr;
  18. };
  19. // Device side kernel params
  20. struct Params {
  21. int const num_blocks_m, num_head, num_batch;
  22. };
  23. static Params
  24. to_underlying_arguments(Arguments const& args) {
  25. return {args.num_blocks_m, args.num_head, args.num_batch};
  26. }
  27. static dim3
  28. get_grid_shape(Params const& params, int num_sm) {
  29. return {uint32_t(params.num_blocks_m), uint32_t(params.num_head), uint32_t(params.num_batch)};
  30. }
  31. struct WorkTileInfo {
  32. int M_idx = 0;
  33. int H_idx = 0;
  34. int B_idx = 0;
  35. bool is_valid_tile = false;
  36. CUTLASS_DEVICE
  37. bool
  38. is_valid(Params const& params) const {
  39. return is_valid_tile;
  40. }
  41. CUTLASS_DEVICE
  42. cute::tuple<int32_t, int32_t, int32_t>
  43. get_block_coord(Params const& params) const {
  44. return {M_idx, H_idx, B_idx};
  45. }
  46. };
  47. CUTLASS_DEVICE
  48. SingleTileSchedulerBwd(SharedStorage* const smem_scheduler) { }
  49. template<bool IsProducer=false>
  50. CUTLASS_DEVICE
  51. WorkTileInfo
  52. get_initial_work(Params const& params) const {
  53. return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
  54. }
  55. CUTLASS_DEVICE
  56. void
  57. init_consumer() const {}
  58. CUTLASS_DEVICE
  59. void
  60. prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
  61. template<bool IsProducer=false>
  62. CUTLASS_DEVICE
  63. WorkTileInfo
  64. get_next_work(Params const& params, WorkTileInfo const& current_work) const {
  65. return {-1, -1, -1, false};
  66. }
  67. };
  68. } // namespace flash