tile_size.h 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  /******************************************************************************
  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  ******************************************************************************/
  4. #pragma once
  5. #include <tuple>
  // Forward-pass tile configuration for SM90.
  //
  // Returns {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap}:
  //   kBlockM, kBlockN           -- tile dimensions (several choices below sit exactly at the
  //                                 shared-memory limit, as noted inline).
  //   Mma1_is_RS, IntraWGOverlap -- boolean kernel options handed back to the caller; their
  //                                 kernel-side meaning is not visible in this header.
  //
  // The numbers are hand-tuned per head-dimension bucket: <= 64, <= 96, <= 128, <= 192, larger.
  // element_size == 2 (presumably 2-byte fp16/bf16 elements) selects the first table. Any other
  // value uses the second table (presumably the FP8 path -- TODO confirm against callers).
  // Not every flag matters everywhere:
  //   - v_colmajor and softcap are only consulted by the non-16-bit table;
  //   - is_causal is only consulted by the 16-bit headdim <= 128 bucket.
  constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
      int headdim, bool is_causal, bool is_local, int element_size=2,
      bool v_colmajor=false, bool paged_kv=false, bool softcap=false) {
      bool const two_byte_elements = (element_size == 2);
      if (two_byte_elements) {
          if (headdim <= 64) {
              // {192, is_causal || is_local ? 192 : 176, true, false} is good for long seqlen
              // (>= 4k) but suffers from tile quantization at short seqlen.
              return {192, 128, true, true};
          }
          if (headdim <= 96) {
              bool const shrink_n = is_local || paged_kv;
              return {192, shrink_n ? 128 : 144, false, true};
          }
          if (headdim <= 128) {
              // {128, 192, false, false} and {192, 128, false, true} are quite good too.
              // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS.
              bool const shrink_n = is_causal || is_local || paged_kv;
              return {128, shrink_n ? 128 : 176, true, true};
          }
          if (headdim <= 192) {
              bool const shrink_n = paged_kv || is_local;
              return {128, shrink_n ? 96 : 112, true, true};  // 128 x 112 hits the limit of smem
          }
          return {128, is_local ? 64 : 80, true, true};  // 128 x 80 hits the limit of smem
      }

      // Every element size other than 2.
      if (headdim <= 64) {
          return {192, 160, true, true};
      }
      if (headdim <= 96) {
          return {192, 128, true, true};
      }
      if (headdim <= 128) {
          if (paged_kv) {
              return {128, 160, true, true};
          }
          bool const medium_n = v_colmajor || (softcap && is_local);
          return {128, medium_n ? 192 : 224, true, true};
      }
      if (headdim <= 192) {
          bool const shrink_n = (paged_kv || softcap) && is_local;
          return {128, shrink_n ? 128 : 160, true, true};
      }
      // PagedKV uses more registers so we disabled IntraWGOverlap
      return {128, is_local ? 64 : 128, true, !paged_kv};
  }
  // Forward-pass tile configuration for SM8x.
  //
  // Returns {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs}. kBlockM is always 128; the other
  // fields are hand-tuned per head-dimension bucket: <= 64, <= 96, <= 128, <= 192, larger.
  // sm86_or_89 selects a separately tuned variant in the <= 128, <= 192 and larger buckets.
  // Only element_size == 2 (presumably 2-byte fp16/bf16 elements -- confirm) has a tuned table.
  // Any other element size gets a single placeholder configuration.
  //
  // is_causal and softcap are currently not consulted. They stay in the signature (marked
  // [[maybe_unused]] to silence -Wunused-parameter) so existing callers keep compiling.
  constexpr std::tuple<int, int, int, int, bool> tile_size_fwd_sm8x(
      bool sm86_or_89, int headdim, [[maybe_unused]] bool is_causal, bool is_local, int element_size=2,
      bool paged_kv=false, bool varlen_and_split=false,
      [[maybe_unused]] bool softcap=false, bool append_kv=false) {
      if (element_size != 2) {
          // Placeholder for now
          return {128, 64, 8, 2, false};
      }
      if (headdim <= 64) {
          int const block_n = varlen_and_split ? 80 : (is_local ? 96 : 112);
          return {128, block_n, 4, 1, false};
      }
      if (headdim <= 96) {
          int const block_n = (varlen_and_split || is_local) ? 48 : 64;
          return {128, block_n, 4, 1, false};
      }
      if (headdim <= 128) {
          // Logical OR: both operands are bool (the previous bitwise `|` gave the same value).
          bool const use_8_warps = sm86_or_89 || varlen_and_split;
          if (!use_8_warps) {
              return {128, is_local ? 48 : 64, 4, 1, false};
          }
          // 8-warp variant keeps Q in registers; is_local always takes the 96 tile.
          int const block_n = is_local ? 96 : (varlen_and_split ? 112 : 128);
          return {128, block_n, 8, 1, true};
      }
      if (headdim <= 192) {
          bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
          return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
      }
      // headdim > 192: always 8 warps, 1 stage.
      bool const narrow_n = varlen_and_split || is_local;
      if (sm86_or_89) {
          int const block_n = append_kv ? 32 : (narrow_n ? 48 : 64);
          return {128, block_n, 8, 1, !append_kv};
      }
      int const block_n = append_kv ? 48 : (narrow_n ? 64 : 96);
      return {128, block_n, 8, 1, false};
  }