12345678910111213141516171819202122232425262728293031323334 |
- /******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
- #include "flash_common.hpp"
- namespace flash {
- int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
- {
- int device;
- auto status = hipGetDevice(&device);
- if(status != hipSuccess)
- return num_splits;
- hipDeviceProp_t props{};
- status = hipGetDeviceProperties(&props, device);
- if(status != hipSuccess)
- return num_splits;
- // TODO - tile size should match the TileFmhaShape, hardcode for now
- const int kM0 = 128;
- const int kN1 = hdim_v;
- const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
- const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
- if(num_splits < 1 && p_drop == 0.0f)
- return num_splits_heuristic_ck(
- batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
- return num_splits;
- }
- } // namespace flash
|