1
0

tile_size.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  /******************************************************************************
   * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
   ******************************************************************************/
  4. #pragma once
  5. #include <tuple>
  // Pick the forward-pass tile configuration for a given problem description.
  //
  // Returns {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap}.
  //
  // Parameters:
  //   headdim      - head dimension; bucketed as <=64, <=96, <=128, <=192, and larger.
  //   is_causal    - only affects the result for 64 < headdim <= 128 when element_size == 2.
  //   is_local     - selects a smaller kBlockN in most buckets.
  //   element_size - 2 selects the first table, any other value selects the second.
  //                  NOTE(review): presumably the element width in bytes (2 = fp16/bf16,
  //                  otherwise fp8) -- confirm against callers.
  //   v_colmajor   - only consulted on the element_size != 2 path (headdim <= 128 bucket).
  //   paged_kv     - reduces kBlockN in some buckets and disables IntraWGOverlap for
  //                  headdim > 192 on the element_size != 2 path.
  //   softcap      - currently not read by this function; kept for interface stability.
  //
  // The function is constexpr, so it can be evaluated at compile time when all
  // arguments are constant expressions.
  constexpr std::tuple<int, int, bool, bool> tile_size_fwd(int headdim, bool is_causal, bool is_local, int element_size=2,
                                                           bool v_colmajor=false, bool paged_kv=false, bool softcap=false) {
      if (element_size == 2) {
          if (headdim <= 64) {
              return {192, 128, true, true};
              // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen
              // return {192, is_causal || is_local ? 192 : 176, true, false};
          } else if (headdim <= 96) {
              return {192, is_local ? 128 : 144, false, true};
          } else if (headdim <= 128) {
              return {128, (is_causal || is_local) ? 128 : 176, true, true};
              // {128, 192, false, false} and {192, 128, false, true} are quite good too
              // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS
          } else if (headdim <= 192) {
              return {128, (paged_kv || is_local) ? 96 : 112, true, true};  // 128 x 112 hits the limit of smem
          } else {
              return {128, is_local ? 64 : 80, true, true};  // 128 x 80 hits the limit of smem
          }
      } else {
          if (headdim <= 64) {
              return {192, 160, true, true};
          } else if (headdim <= 96) {
              return {192, 128, true, true};
          } else if (headdim <= 128) {
              return {128, (v_colmajor || (paged_kv && is_local)) ? 192 : 224, true, true};
          } else if (headdim <= 192) {
              return {128, (paged_kv && is_local) ? 128 : 160, true, true};
          } else {
              // PagedKV uses more registers so we disabled IntraWGOverlap
              return {128, is_local ? 64 : 128, true, !paged_kv};
          }
      }
  }