flash_common.cpp 1.1 KB

12345678910111213141516171819202122232425262728293031323334
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
  4. #include "flash_common.hpp"
  5. namespace flash {
  6. int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
  7. {
  8. int device;
  9. auto status = hipGetDevice(&device);
  10. if(status != hipSuccess)
  11. return num_splits;
  12. hipDeviceProp_t props{};
  13. status = hipGetDeviceProperties(&props, device);
  14. if(status != hipSuccess)
  15. return num_splits;
  16. // TODO - tile size should match the TileFmhaShape, hardcode for now
  17. const int kM0 = 128;
  18. const int kN1 = hdim_v;
  19. const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
  20. const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
  21. if(num_splits < 1 && p_drop == 0.0f)
  22. return num_splits_heuristic_ck(
  23. batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
  24. return num_splits;
  25. }
  26. } // namespace flash