heuristics.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  #include <algorithm>
  #include <cmath>
  #include <vector>
  /// Decide whether to use PackGQA.
  ///
  /// The heuristic compares how well two row counts fill tiles of blockM rows:
  ///   - unpacked: seqlen_q rows
  ///   - packed:   seqlen_q * qhead_per_khead rows
  /// PackGQA is a bit slower, so it is only chosen when the unpacked tiling's
  /// efficiency is below 90% of the packed tiling's efficiency. Efficiency here
  /// means rows / round_up(rows, blockM).
  ///
  /// @param varlen_q        true for variable-length queries; then only
  ///                        max_seqlen_q is known, so packing is always used.
  /// @param seqlen_q        number of query rows (a value <= 0 yields false).
  /// @param qhead_per_khead number of query heads per KV head.
  /// @param blockM          tile size in rows; assumed > 0 (TODO confirm callers).
  /// @return true if PackGQA should be used.
  inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
    if (varlen_q) { return true; }
    // With no rows both efficiencies would be 0/0 (NaN), and the comparison
    // below would fall through to false; make that outcome explicit.
    if (seqlen_q <= 0) { return false; }
    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM.
    // Use 64-bit arithmetic: seqlen_q * qhead_per_khead and (a + b - 1) can
    // overflow int (undefined behaviour) for long sequences.
    auto round_up = [](long long a, long long b) { return (a + b - 1) / b * b; };
    const long long rows_nopack = seqlen_q;
    const long long rows_pack = rows_nopack * qhead_per_khead;
    float nopack_gqa_efficiency = float(rows_nopack) / float(round_up(rows_nopack, blockM));
    float pack_gqa_efficiency = float(rows_pack) / float(round_up(rows_pack, blockM));
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
  }
  // Find the number of splits that maximizes the occupancy. For example, if we have
  // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
  // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
  // splits as that would incur more HBM reads/writes.
  // So we find the best efficiency, then find the smallest number of splits that gets 85%
  // of the best efficiency.
  //
  // For a split count s, efficiency(s) = n_waves / ceil(n_waves), where
  // n_waves = batch_nheads_mblocks * s / num_SMs. This is the fraction of the
  // last wave that is occupied.
  //
  // @param batch_nheads_mblocks number of work units before splitting
  //                             (presumably batch * heads * M-blocks, per the name).
  // @param num_SMs              number of SMs available.
  // @param num_n_blocks         number of blocks along the split dimension; caps the split count.
  // @param max_splits           caller-requested upper bound on the split count.
  // @return the chosen number of splits (1 when splitting is not worthwhile).
  inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if (num_n_blocks <= 4) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    // The efficiency is a pure function of num_splits. Recomputing it in the
    // second pass gives identical values and avoids allocating a vector to cache them.
    auto efficiency = [&](int num_splits) {
      float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
      return n_waves / std::ceil(n_waves);
    };
    float max_efficiency = 0.f;
    for (int num_splits = 1; num_splits <= max_splits; ++num_splits) {
      max_efficiency = std::max(max_efficiency, efficiency(num_splits));
    }
    // Smallest split count within 85% of the best efficiency.
    for (int num_splits = 1; num_splits <= max_splits; ++num_splits) {
      if (efficiency(num_splits) >= 0.85 * max_efficiency) { return num_splits; }
    }
    return 1;
  }