  1. """
  2. This script is mainly used to tests various hidden_sizes. We have collected the
  3. hidden_sizes included in the LoRA models currently supported by vLLM. It tests
  4. whether the corresponding Triton kernel can run normally when tensor parallelism
  5. is set to [1, 2, 4, 8, 16, 32, 64].
  6. """
import random
from unittest.mock import patch

import pytest
import torch

from aphrodite.lora.ops.bgmv_expand import bgmv_expand
from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
from aphrodite.lora.ops.sgmv_expand import sgmv_expand
from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
from aphrodite.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
                    ref_torch_groupgemm)

HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]

# The TP (tensor parallel) sizes to divide the hidden sizes by.
divisibility = [1, 2, 8, 16, 64]

all_hidden_size = []
for div in divisibility:
    for hidden_size in HIDDEN_SIZES:
        all_hidden_size.append(hidden_size // div)

HIDDEN_SIZES = list(set(all_hidden_size))
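# For example, a base hidden size of 4096 under TP=8 contributes
# 4096 // 8 = 512, which is then exercised as a hidden size of its own.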

BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]


def assert_close(a, b):
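    # Loose tolerances for the half-precision dtypes: the Triton kernels and
    # the PyTorch reference are not expected to accumulate in exactly the
    # same order or precision.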
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    seq_length = 128
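    # SGMV covers the multi-token (prefill-style) case, so every sequence in
    # the batch carries 128 tokens here.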
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
    max_seq_length = seq_len_tensor.max()
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()
    if op_type == "shrink":
        sgmv_shrink(
            inputs_tensor,
            lora_weights,
            our_out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            lora_indices_tensor,
            batches,
            max_seq_length,
            scaling,
        )
    else:
        sgmv_expand(
            inputs_tensor,
            lora_weights,
            our_out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            lora_indices_tensor,
            batches,
            max_seq_length,
            add_inputs=True,
        )
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if op_type == "shrink" else 1.0,
        op_type,
    )
    if op_type == "shrink":
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    from aphrodite.lora.ops.bgmv_expand import _bgmv_expand_kernel
    from aphrodite.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    seq_length = 1
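    # BGMV covers the single-token (decode-style) case, hence one token per
    # sequence.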
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
    if op_type == "shrink":
        # The current _bgmv_shrink_kernel does not require the libentry
        # decoration. The purpose of adding this patch is to test the
        # correctness of libentry.
        with patch(
                "aphrodite.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
                LibEntry(_bgmv_shrink_kernel),
        ):
            bgmv_shrink(
                inputs_tensor,
                lora_weights,
                our_out_tensor,
                indices,
                scaling,
            )
    else:
        # ditto
        with patch(
                "aphrodite.lora.ops.bgmv_expand._bgmv_expand_kernel",
                LibEntry(_bgmv_expand_kernel),
        ):
            bgmv_expand(
                inputs_tensor,
                lora_weights,
                our_out_tensor,
                indices,
                add_inputs=True,
            )
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if op_type == "shrink" else 1.0,
        op_type,
    )
    if op_type == "shrink":
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    from aphrodite.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    seq_length = 128 if op_type == "sgmv" else 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )
    max_seq_length = seq_len_tensor.max()
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()
    slice_offset = 0
    for index in range(nslices):
        lora_weights = lora_weights_lst[index]
        if op_type == "sgmv":
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # libentry decoration. The purpose of adding this patch is to test
            # the correctness of libentry.
            with patch(
                    "aphrodite.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                    LibEntry(_bgmv_expand_slice_kernel),
            ):
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )
        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_offset + hidden_size],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )
        slice_offset += hidden_size
    assert_close(our_outputs, ref_outputs)