moe.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.modeling.layers.fused_moe.fused_moe import *
  12. class BenchmarkConfig(TypedDict):
  13. BLOCK_SIZE_M: int
  14. BLOCK_SIZE_N: int
  15. BLOCK_SIZE_K: int
  16. GROUP_SIZE_M: int
  17. num_warps: int
  18. num_stages: int
  19. def benchmark_config(
  20. config: BenchmarkConfig,
  21. num_tokens: int,
  22. num_experts: int,
  23. shard_intermediate_size: int,
  24. hidden_size: int,
  25. topk: int,
  26. dtype: torch.dtype,
  27. use_fp8: bool,
  28. num_iters: int = 100,
  29. ) -> float:
  30. init_dtype = torch.float16 if use_fp8 else dtype
  31. x = torch.randn(num_tokens, hidden_size, dtype=dtype)
  32. w1 = torch.randn(num_experts,
  33. shard_intermediate_size,
  34. hidden_size,
  35. dtype=init_dtype)
  36. w2 = torch.randn(num_experts,
  37. hidden_size,
  38. shard_intermediate_size // 2,
  39. dtype=init_dtype)
  40. gating_output = torch.randn(num_iters,
  41. num_tokens,
  42. num_experts,
  43. dtype=torch.float32)
  44. w1_scale = None
  45. w2_scale = None
  46. a1_scale = None
  47. a2_scale = None
  48. if use_fp8:
  49. w1_scale = torch.randn(num_experts, dtype=torch.float32)
  50. w2_scale = torch.randn(num_experts, dtype=torch.float32)
  51. a1_scale = torch.randn(1, dtype=torch.float32)
  52. a2_scale = torch.randn(1, dtype=torch.float32)
  53. w1 = w1.to(torch.float8_e4m3fn)
  54. w2 = w2.to(torch.float8_e4m3fn)
  55. input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
  56. def prepare(i: int):
  57. input_gating.copy_(gating_output[i])
  58. def run():
  59. fused_moe(
  60. x,
  61. w1,
  62. w2,
  63. input_gating,
  64. topk,
  65. renormalize=True,
  66. inplace=True,
  67. override_config=config,
  68. use_fp8=use_fp8,
  69. w1_scale=w1_scale,
  70. w2_scale=w2_scale,
  71. a1_scale=a1_scale,
  72. a2_scale=a2_scale,
  73. )
  74. # JIT compilation & warmup
  75. run()
  76. torch.cuda.synchronize()
  77. # Capture 10 invocations with CUDA graph
  78. graph = torch.cuda.CUDAGraph()
  79. with torch.cuda.graph(graph):
  80. for _ in range(10):
  81. run()
  82. torch.cuda.synchronize()
  83. # Warmup
  84. for _ in range(5):
  85. graph.replay()
  86. torch.cuda.synchronize()
  87. start_event = torch.cuda.Event(enable_timing=True)
  88. end_event = torch.cuda.Event(enable_timing=True)
  89. latencies: List[float] = []
  90. for i in range(num_iters):
  91. prepare(i)
  92. torch.cuda.synchronize()
  93. start_event.record()
  94. graph.replay()
  95. end_event.record()
  96. end_event.synchronize()
  97. latencies.append(start_event.elapsed_time(end_event))
  98. avg = sum(latencies) / (num_iters * 10) * 1000 # us
  99. graph.reset()
  100. return avg
  101. def get_configs_compute_bound() -> List[Dict[str, int]]:
  102. # Reduced search space for faster tuning.
  103. # TODO(woosuk): Increase the search space and use a performance model to
  104. # prune the search space.
  105. configs: List[BenchmarkConfig] = []
  106. for num_stages in [2, 3, 4, 5]:
  107. for block_m in [16, 32, 64, 128, 256]:
  108. for block_k in [64, 128, 256]:
  109. for block_n in [32, 64, 128, 256]:
  110. for num_warps in [4, 8]:
  111. for group_size in [1, 16, 32, 64]:
  112. configs.append({
  113. "BLOCK_SIZE_M": block_m,
  114. "BLOCK_SIZE_N": block_n,
  115. "BLOCK_SIZE_K": block_k,
  116. "GROUP_SIZE_M": group_size,
  117. "num_warps": num_warps,
  118. "num_stages": num_stages,
  119. })
  120. return configs
  121. @ray.remote(num_gpus=1)
  122. class BenchmarkWorker:
  123. def __init__(self, seed: int) -> None:
  124. torch.set_default_device("cuda")
  125. torch.cuda.manual_seed_all(seed)
  126. self.seed = seed
  127. def benchmark(
  128. self,
  129. num_tokens: int,
  130. num_experts: int,
  131. shard_intermediate_size: int,
  132. hidden_size: int,
  133. topk: int,
  134. dtype: torch.dtype,
  135. use_fp8: bool,
  136. ) -> Tuple[Dict[str, int], float]:
  137. torch.cuda.manual_seed_all(self.seed)
  138. dtype_str = "float8" if use_fp8 else None
  139. # NOTE(woosuk): The current naming convention uses w2.shape[2], which
  140. # is the intermediate size after silu_and_mul.
  141. op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
  142. dtype_str)
  143. if op_config is None:
  144. config = get_default_config(num_tokens, num_experts,
  145. shard_intermediate_size, hidden_size,
  146. topk, dtype_str)
  147. else:
  148. config = op_config[min(op_config.keys(),
  149. key=lambda x: abs(x - num_tokens))]
  150. kernel_time = benchmark_config(config, num_tokens, num_experts,
  151. shard_intermediate_size, hidden_size,
  152. topk, dtype, use_fp8)
  153. return config, kernel_time
  154. def tune(
  155. self,
  156. num_tokens: int,
  157. num_experts: int,
  158. shard_intermediate_size: int,
  159. hidden_size: int,
  160. topk: int,
  161. dtype: torch.dtype,
  162. use_fp8: bool,
  163. search_space: List[BenchmarkConfig],
  164. ) -> BenchmarkConfig:
  165. best_config = None
  166. best_time = float("inf")
  167. for config in tqdm(search_space):
  168. try:
  169. kernel_time = benchmark_config(config,
  170. num_tokens,
  171. num_experts,
  172. shard_intermediate_size,
  173. hidden_size,
  174. topk,
  175. dtype,
  176. use_fp8,
  177. num_iters=10)
  178. except triton.runtime.autotuner.OutOfResources:
  179. # Some configurations may be invalid and fail to compile.
  180. continue
  181. if kernel_time < best_time:
  182. best_time = kernel_time
  183. best_config = config
  184. now = datetime.now()
  185. print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
  186. assert best_config is not None
  187. return best_config
  188. def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
  189. return {
  190. "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
  191. "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
  192. "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
  193. "GROUP_SIZE_M": config["GROUP_SIZE_M"],
  194. "num_warps": config["num_warps"],
  195. "num_stages": config["num_stages"],
  196. }
  197. def save_configs(
  198. configs: Dict[int, BenchmarkConfig],
  199. num_experts: int,
  200. shard_intermediate_size: int,
  201. hidden_size: int,
  202. topk: int,
  203. dtype: torch.dtype,
  204. use_fp8: bool,
  205. ) -> None:
  206. dtype_str = "float8" if use_fp8 else None
  207. # NOTE(woosuk): The current naming convention uses w2.shape[2], which
  208. # is the intermediate size after silu_and_mul.
  209. filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
  210. dtype_str)
  211. print(f"Writing best config to {filename}...")
  212. with open(filename, "w") as f:
  213. json.dump(configs, f, indent=4)
  214. f.write("\n")
  215. def main(args: argparse.Namespace):
  216. print(args)
  217. config = AutoConfig.from_pretrained(args.model)
  218. if config.architectures[0] == "DbrxForCausalLM":
  219. E = config.ffn_config.moe_num_experts
  220. topk = config.ffn_config.moe_top_k
  221. intermediate_size = config.ffn_config.ffn_hidden_size
  222. shard_intermediate_size = 2 * intermediate_size // args.tp_size
  223. else:
  224. # Default: Mixtral.
  225. E = config.num_local_experts
  226. topk = config.num_experts_per_tok
  227. intermediate_size = config.intermediate_size
  228. shard_intermediate_size = 2 * intermediate_size // args.tp_size
  229. hidden_size = config.hidden_size
  230. dtype = config.torch_dtype
  231. use_fp8 = args.dtype == "fp8"
  232. if args.batch_size is None:
  233. batch_sizes = [
  234. 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
  235. 2048, 3072, 4096
  236. ]
  237. else:
  238. batch_sizes = [args.batch_size]
  239. ray.init()
  240. num_gpus = int(ray.available_resources()["GPU"])
  241. workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
  242. def _distribute(method: str, inputs: List[Any]) -> List[Any]:
  243. outputs = []
  244. worker_idx = 0
  245. for input_args in inputs:
  246. worker = workers[worker_idx]
  247. worker_method = getattr(worker, method)
  248. output = worker_method.remote(*input_args)
  249. outputs.append(output)
  250. worker_idx = (worker_idx + 1) % num_gpus
  251. return ray.get(outputs)
  252. if args.tune:
  253. search_space = get_configs_compute_bound()
  254. print(f"Start tuning over {len(search_space)} configurations...")
  255. start = time.time()
  256. configs = _distribute(
  257. "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
  258. topk, dtype, use_fp8, search_space)
  259. for batch_size in batch_sizes])
  260. best_configs = {
  261. M: sort_config(config)
  262. for M, config in zip(batch_sizes, configs)
  263. }
  264. save_configs(best_configs, E, shard_intermediate_size, hidden_size,
  265. topk, dtype, use_fp8)
  266. end = time.time()
  267. print(f"Tuning took {end - start:.2f} seconds")
  268. else:
  269. outputs = _distribute("benchmark",
  270. [(batch_size, E, shard_intermediate_size,
  271. hidden_size, topk, dtype, use_fp8)
  272. for batch_size in batch_sizes])
  273. for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
  274. print(f"Batch size: {batch_size}, config: {config}")
  275. print(f"Kernel time: {kernel_time:.2f} us")
  276. if __name__ == "__main__":
  277. parser = FlexibleArgumentParser()
  278. parser.add_argument("--model",
  279. type=str,
  280. default="mistralai/Mixtral-8x7B-Instruct-v0.1")
  281. parser.add_argument("--tp-size", "-tp", type=int, default=2)
  282. parser.add_argument("--dtype",
  283. type=str,
  284. choices=["auto", "fp8"],
  285. default="auto")
  286. parser.add_argument("--seed", type=int, default=0)
  287. parser.add_argument("--batch-size", type=int, required=False)
  288. parser.add_argument("--tune", action="store_true")
  289. args = parser.parse_args()
  290. main(args)