123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
#pragma once

#include <algorithm>
#include <cmath>
#include <vector>
// Decides whether the query heads that share one KV head should be packed
// together along the row dimension before tiling into blocks of blockM rows.
//
// varlen_q:         when true the function always answers "pack".
//                   NOTE(review): presumably because only an upper bound on the
//                   per-sequence length is known in that mode -- confirm with caller.
// seqlen_q:         number of query rows per head.
// qhead_per_khead:  number of query heads per KV head.
// blockM:           tile height in rows; must be non-zero (it is used as a divisor).
//
// Returns true when the unpacked tiling uses less than 90% of the row
// utilization the packed tiling would achieve.
inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    if (varlen_q) {
        return true;
    }

    // Fraction of rows that carry real data once `rows` is padded up to the
    // next multiple of blockM.
    auto tile_utilization = [blockM](int rows) -> float {
        const int padded_rows = (rows + blockM - 1) / blockM * blockM;
        return float(rows) / float(padded_rows);
    };

    const float unpacked_utilization = tile_utilization(seqlen_q);
    const float packed_utilization = tile_utilization(seqlen_q * qhead_per_khead);

    // The comparison is deliberately carried out in double (0.9 is a double
    // literal), matching the original arithmetic.
    return unpacked_utilization < 0.9 * packed_utilization;
}
// Chooses how many splits to use so that the resulting work items fill whole
// "waves" of num_SMs as evenly as possible.
//
// batch_nheads_mblocks: number of work items before splitting.
// num_SMs:              number of work items that make up one wave
//                       (NOTE(review): by its name, the SM count of the device -- confirm).
// num_n_blocks:         number of blocks along the dimension being split.
// max_splits:           caller-imposed upper bound on the split count.
//
// Returns 1 when the unsplit work already reaches 80% of num_SMs, when
// num_n_blocks <= 4, or when no candidate qualifies. Otherwise returns the
// smallest split count in [1, min(max_splits, num_SMs, num_n_blocks)] whose
// wave efficiency is at least 85% of the best efficiency in that range.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // Unsplit work already covers at least 80% of a wave: splitting is not worth it.
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    // Too few blocks along the split dimension to be worth dividing.
    if (num_n_blocks <= 4) { return 1; }

    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    // Efficiency of a split count = (waves needed) / (waves rounded up), i.e.
    // how full the final wave is. Evaluates to NaN when there is no work
    // (0 / 0); every comparison below is then false and the function falls
    // through to "return 1".
    auto wave_efficiency = [batch_nheads_mblocks, num_SMs](int num_splits) -> float {
        const float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
        return n_waves / std::ceil(n_waves);
    };

    // Pass 1: best achievable efficiency over all candidate split counts.
    float max_efficiency = 0.f;
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        const float eff = wave_efficiency(num_splits);
        if (eff > max_efficiency) { max_efficiency = eff; }
    }

    // Pass 2: smallest split count within 85% of the best. The formula is
    // re-evaluated instead of being cached in a heap-allocated buffer; the
    // threshold is kept in double to match the original comparison.
    const double threshold = 0.85 * max_efficiency;
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (wave_efficiency(num_splits) >= threshold) {
            return num_splits;
        }
    }
    return 1;
}
|