1
0

marlin.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. from typing import List
  2. import torch
  3. import torch.utils.benchmark as benchmark
  4. from benchmark_shapes import WEIGHT_SHAPES
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.common.utils import FlexibleArgumentParser
  7. from aphrodite.quantization.gptq_marlin_24 import (
  8. GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
  9. GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
  12. GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
  13. from aphrodite.quantization.utils.marlin_utils_test import (MarlinWorkspace,
  14. marlin_quantize)
  15. from aphrodite.quantization.utils.marlin_utils_test_24 import (
  16. marlin_24_quantize)
  17. from aphrodite.quantization.utils.quant_utils import (gptq_pack,
  18. quantize_weights,
  19. sort_weights)
  20. DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
  21. DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
  22. ACT_ORDER_OPTS = [False, True]
  23. K_FULL_OPTS = [False, True]
  24. def bench_run(results: List[benchmark.Measurement], model: str,
  25. act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
  26. size_m: int, size_k: int, size_n: int):
  27. label = "Quant Matmul"
  28. sub_label = ("{}, act={} k_full={}, b={}, g={}, "
  29. "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
  30. group_size, size_m, size_k, size_n))
  31. print(f"Testing: {sub_label}")
  32. a = torch.randn(size_m, size_k).to(torch.half).cuda()
  33. b = torch.rand(size_k, size_n).to(torch.half).cuda()
  34. a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
  35. # Marlin quant
  36. (
  37. marlin_w_ref,
  38. marlin_q_w,
  39. marlin_s,
  40. marlin_g_idx,
  41. marlin_sort_indices,
  42. marlin_rand_perm,
  43. ) = marlin_quantize(b, num_bits, group_size, act_order)
  44. # Marlin_24 quant
  45. (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
  46. marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
  47. # GPTQ quant
  48. (w_ref, q_w, s, g_idx,
  49. rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
  50. q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
  51. # For act_order, sort the "weights" and "g_idx"
  52. # so that group ids are increasing
  53. repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
  54. if act_order:
  55. (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
  56. # Prepare
  57. marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
  58. GPTQ_MARLIN_MAX_PARALLEL)
  59. marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
  60. GPTQ_MARLIN_24_MAX_PARALLEL)
  61. globals = {
  62. # Gen params
  63. "num_bits": num_bits,
  64. "group_size": group_size,
  65. "size_m": size_m,
  66. "size_n": size_n,
  67. "size_k": size_k,
  68. "a": a,
  69. "a_tmp": a_tmp,
  70. # Marlin params
  71. "marlin_w_ref": marlin_w_ref,
  72. "marlin_q_w": marlin_q_w,
  73. "marlin_s": marlin_s,
  74. "marlin_g_idx": marlin_g_idx,
  75. "marlin_sort_indices": marlin_sort_indices,
  76. "marlin_rand_perm": marlin_rand_perm,
  77. "marlin_workspace": marlin_workspace,
  78. "is_k_full": is_k_full,
  79. # Marlin_24 params
  80. "marlin_24_w_ref": marlin_24_w_ref,
  81. "marlin_24_q_w_comp": marlin_24_q_w_comp,
  82. "marlin_24_meta": marlin_24_meta,
  83. "marlin_24_s": marlin_24_s,
  84. "marlin_24_workspace": marlin_24_workspace,
  85. # GPTQ params
  86. "q_w_gptq": q_w_gptq,
  87. "repack_sort_indices": repack_sort_indices,
  88. # Kernels
  89. "gptq_marlin_gemm": ops.gptq_marlin_gemm,
  90. "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
  91. "gptq_marlin_repack": ops.gptq_marlin_repack,
  92. }
  93. min_run_time = 1
  94. # Warmup pytorch
  95. for i in range(5):
  96. torch.matmul(a, marlin_w_ref)
  97. results.append(
  98. benchmark.Timer(
  99. stmt="torch.matmul(a, marlin_w_ref)",
  100. globals=globals,
  101. label=label,
  102. sub_label=sub_label,
  103. description="pytorch_gemm",
  104. ).blocked_autorange(min_run_time=min_run_time))
  105. results.append(
  106. benchmark.Timer(
  107. stmt=
  108. "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
  109. globals=globals,
  110. label=label,
  111. sub_label=sub_label,
  112. description="gptq_marlin_gemm",
  113. ).blocked_autorange(min_run_time=min_run_time))
  114. if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
  115. and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
  116. results.append(
  117. benchmark.Timer(
  118. stmt=
  119. "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
  120. globals=globals,
  121. label=label,
  122. sub_label=sub_label,
  123. description="gptq_marlin_24_gemm",
  124. ).blocked_autorange(min_run_time=min_run_time))
  125. results.append(
  126. benchmark.Timer(
  127. stmt=
  128. "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
  129. globals=globals,
  130. label=label,
  131. sub_label=sub_label,
  132. description="gptq_marlin_repack",
  133. ).blocked_autorange(min_run_time=min_run_time))
  134. def main(args):
  135. print("Benchmarking models:")
  136. for i, model in enumerate(args.models):
  137. print(f"[{i}] {model}")
  138. results: List[benchmark.Measurement] = []
  139. for model in args.models:
  140. for layer in WEIGHT_SHAPES[model]:
  141. size_k = layer[0]
  142. size_n = layer[1]
  143. if len(args.limit_k) > 0 and size_k not in args.limit_k:
  144. continue
  145. if len(args.limit_n) > 0 and size_n not in args.limit_n:
  146. continue
  147. for act_order in ACT_ORDER_OPTS:
  148. if len(args.limit_act_order
  149. ) > 0 and act_order not in args.limit_act_order:
  150. continue
  151. for is_k_full in K_FULL_OPTS:
  152. if len(args.limit_k_full
  153. ) > 0 and is_k_full not in args.limit_k_full:
  154. continue
  155. for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
  156. if len(args.limit_num_bits
  157. ) > 0 and num_bits not in args.limit_num_bits:
  158. continue
  159. for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
  160. if len(
  161. args.limit_group_size
  162. ) > 0 and group_size not in args.limit_group_size:
  163. continue
  164. # For act_order, the group_size must be less than
  165. # size_k
  166. if act_order and (group_size == size_k
  167. or group_size == -1):
  168. continue
  169. for size_m in args.batch_sizes:
  170. bench_run(results, model, act_order, is_k_full,
  171. num_bits, group_size, size_m, size_k,
  172. size_n)
  173. compare = benchmark.Compare(results)
  174. compare.print()
  175. # For quick benchmarking use:
  176. # python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
  177. #
  178. if __name__ == "__main__":
  179. parser = FlexibleArgumentParser(
  180. description="Benchmark Marlin across specified models/shapes/batches")
  181. parser.add_argument(
  182. "--models",
  183. nargs="+",
  184. type=str,
  185. default=DEFAULT_MODELS,
  186. choices=WEIGHT_SHAPES.keys(),
  187. )
  188. parser.add_argument("--batch-sizes",
  189. nargs="+",
  190. type=int,
  191. default=DEFAULT_BATCH_SIZES)
  192. parser.add_argument("--limit-k", nargs="+", type=int, default=[])
  193. parser.add_argument("--limit-n", nargs="+", type=int, default=[])
  194. parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
  195. parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
  196. parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
  197. parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
  198. args = parser.parse_args()
  199. main(args)