# moe_config.py
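"""Grid-search tuner for the Aphrodite fused MoE Triton kernel.

For each batch size, this script benchmarks every combination of tile
sizes, group size, and warp count, then records the fastest config in the
JSON file named by get_config_file_name(). The model shape below matches
Mixtral-8x7B (8 experts, top-2 routing, d_model=4096, FFN=14336) sharded
across 2 tensor-parallel GPUs.
"""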

import json
import os
import sys

import torch
import torch.nn.functional as F
import triton

from aphrodite.modeling.layers.fused_moe import fused_moe, get_config_file_name

# Pin the benchmark to GPU 0; torch creates its CUDA context lazily, so
# setting this after the import still takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096,
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
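    # Cap BLOCK_SIZE_M by the batch size: M-tiles taller than the rows the
    # batch can supply would spend most of their work on padding.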
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]
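
    # Full cross-product of the remaining tile parameters; num_stages is
    # fixed at 4 throughout the sweep.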
    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20
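
    # Try every candidate; configs whose tiles exceed the GPU's shared
    # memory raise OutOfResources during warmup and are skipped.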
    for config in configs:
        print(f"{tp_size=} {bs=}")
        print(f"{config}")

        # warmup
        print("warming up")
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        print("benchmarking")
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
                  f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
                  f"{d_model=} {model_intermediate_size=} {num_layers=}")

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # The config file maps batch size -> best kernel config
    # (Dict[str, Dict[str, int]]); merge with any existing entries.
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(
    num_calls: int,
    bs: int,
    d_model: int,
    num_total_experts: int,
    top_k: int,
    tp_size: int,
    model_intermediate_size: int,
    method,
    config,
) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )
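
    # Random expert weights: w1 fuses the gate and up projections (hence
    # 2x the sharded intermediate size); w2 projects back to d_model.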
    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
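
    # One softmaxed routing distribution per timed call, generated up front
    # so RNG and softmax cost stay outside the timed region.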
    gating_output = F.softmax(
        torch.rand(
            (num_calls, bs, num_total_experts),
            device=hidden_states.device,
            dtype=torch.float32,
        ),
        dim=-1,
    )
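
    # Time num_calls back-to-back invocations with CUDA events and report
    # the mean per-call latency in milliseconds.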
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())