  1. """
  2. This script is mainly used to tests various hidden_sizes. We have collected the
  3. hidden_sizes included in the LoRA models currently supported by vLLM. It tests
  4. whether the corresponding Triton kernel can run normally when tensor parallelism
  5. is set to [1, 2, 4, 8, 16, 32, 64].
  6. """
import random
from unittest.mock import patch

import pytest
import torch

from aphrodite.lora.ops.bgmv_expand import bgmv_expand
from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
from aphrodite.lora.ops.sgmv_expand import sgmv_expand
from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
from aphrodite.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
                    ref_torch_groupgemm)

HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]

# The size of TP
divisibility = [1, 2, 8, 16, 64]

all_hidden_size = []
for div in divisibility:
    for hidden_size in HIDDEN_SIZES:
        all_hidden_size.append(hidden_size // div)

HIDDEN_SIZES = list(set(all_hidden_size))
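# For illustration: with hidden_size = 4096 and a TP degree of 8, each rank
# holds a 4096 // 8 = 512-wide shard, so 512 must also be exercised. The
# set() call removes shard sizes that several divisors produce in common.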

BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]


def assert_close(a, b):
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
  117. @pytest.mark.parametrize("batches", BATCHES)
  118. @pytest.mark.parametrize("num_loras", NUM_LORA)
  119. @pytest.mark.parametrize("rank", MAX_RANKS)
  120. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  121. @pytest.mark.parametrize("scaling", SCALES)
  122. @pytest.mark.parametrize("dtype", DTYPES)
  123. @pytest.mark.parametrize("op_type", ["shrink", "expand"])
  124. @pytest.mark.parametrize("seed", SEED)
  125. @pytest.mark.parametrize("device", CUDA_DEVICES)
  126. def test_punica_sgmv(
  127. batches: int,
  128. num_loras: int,
  129. rank: int,
  130. hidden_size: int,
  131. scaling: float,
  132. dtype: torch.dtype,
  133. op_type: str,
  134. seed: int,
  135. device: str,
  136. ):
  137. random.seed(seed)
  138. torch.set_default_device(device)
  139. torch.random.manual_seed(seed)
  140. if torch.cuda.is_available():
  141. torch.cuda.manual_seed(seed)
  142. seq_length = 128
  143. (
  144. inputs_tensor,
  145. lora_weights,
  146. our_out_tensor,
  147. ref_out_tensor,
  148. b_seq_start_loc,
  149. lora_indices_tensor,
  150. seq_len_tensor,
  151. indices,
  152. ) = generate_data(
  153. batches,
  154. hidden_size,
  155. num_loras,
  156. rank,
  157. seq_length,
  158. dtype,
  159. op_type,
  160. device,
  161. )
  162. max_seq_length = seq_len_tensor.max()
  163. token_nums = seq_len_tensor.sum().item()
  164. if isinstance(max_seq_length, tuple):
  165. max_seq_length = max_seq_length[0].item()
  166. else:
  167. max_seq_length = max_seq_length.item()
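    # max_seq_length and token_nums are plain Python scalars required by the
    # SGMV ops below.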
  168. if op_type == "shrink":
  169. sgmv_shrink(
  170. inputs_tensor,
  171. lora_weights,
  172. our_out_tensor,
  173. b_seq_start_loc,
  174. seq_len_tensor,
  175. lora_indices_tensor,
  176. batches,
  177. max_seq_length,
  178. token_nums,
  179. scaling,
  180. )
  181. else:
  182. sgmv_expand(
  183. inputs_tensor,
  184. lora_weights,
  185. our_out_tensor,
  186. b_seq_start_loc,
  187. seq_len_tensor,
  188. lora_indices_tensor,
  189. batches,
  190. max_seq_length,
  191. token_nums,
  192. add_inputs=True,
  193. )
  194. ref_torch_groupgemm(
  195. ref_out_tensor,
  196. inputs_tensor,
  197. lora_weights,
  198. lora_indices_tensor,
  199. seq_len_tensor,
  200. batches,
  201. scaling if op_type == "shrink" else 1.0,
  202. op_type,
  203. )
  204. if op_type == "shrink":
  205. ref_out_tensor = ref_out_tensor.to(torch.float32)
  206. assert_close(our_out_tensor, ref_out_tensor)
  207. @pytest.mark.parametrize("batches", BATCHES)
  208. @pytest.mark.parametrize("num_loras", NUM_LORA)
  209. @pytest.mark.parametrize("rank", MAX_RANKS)
  210. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  211. @pytest.mark.parametrize("scaling", SCALES)
  212. @pytest.mark.parametrize("dtype", DTYPES)
  213. @pytest.mark.parametrize("op_type", ["shrink", "expand"])
  214. @pytest.mark.parametrize("seed", SEED)
  215. @pytest.mark.parametrize("device", CUDA_DEVICES)
  216. def test_punica_bgmv(
  217. batches: int,
  218. num_loras: int,
  219. rank: int,
  220. hidden_size: int,
  221. scaling: float,
  222. dtype: torch.dtype,
  223. op_type: str,
  224. seed: int,
  225. device: str,
  226. ):
  227. from aphrodite.lora.ops.bgmv_expand import _bgmv_expand_kernel
  228. from aphrodite.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
  229. random.seed(seed)
  230. torch.set_default_device(device)
  231. torch.random.manual_seed(seed)
  232. if torch.cuda.is_available():
  233. torch.cuda.manual_seed(seed)
  234. seq_length = 1
  235. (
  236. inputs_tensor,
  237. lora_weights,
  238. our_out_tensor,
  239. ref_out_tensor,
  240. b_seq_start_loc,
  241. lora_indices_tensor,
  242. seq_len_tensor,
  243. indices,
  244. ) = generate_data(
  245. batches,
  246. hidden_size,
  247. num_loras,
  248. rank,
  249. seq_length,
  250. dtype,
  251. op_type,
  252. device,
  253. )
  254. if op_type == "shrink":
  255. # The current _bgmv_shrink_kernel does not require the libentry
  256. # decoration. The purpose of adding this patch is to test the
  257. # correctness of libentry.
  258. with patch(
  259. "aphrodite.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
  260. LibEntry(_bgmv_shrink_kernel),
  261. ):
  262. bgmv_shrink(
  263. inputs_tensor,
  264. lora_weights,
  265. our_out_tensor,
  266. indices,
  267. scaling,
  268. )
  269. else:
  270. # ditto
  271. with patch(
  272. "aphrodite.lora.ops.bgmv_expand._bgmv_expand_kernel",
  273. LibEntry(_bgmv_expand_kernel),
  274. ):
  275. bgmv_expand(
  276. inputs_tensor,
  277. lora_weights,
  278. our_out_tensor,
  279. indices,
  280. add_inputs=True,
  281. )
  282. ref_torch_groupgemm(
  283. ref_out_tensor,
  284. inputs_tensor,
  285. lora_weights,
  286. lora_indices_tensor,
  287. seq_len_tensor,
  288. batches,
  289. scaling if op_type == "shrink" else 1.0,
  290. op_type,
  291. )
  292. if op_type == "shrink":
  293. ref_out_tensor = ref_out_tensor.to(torch.float32)
  294. assert_close(our_out_tensor, ref_out_tensor)
  295. @pytest.mark.parametrize("batches", BATCHES)
  296. @pytest.mark.parametrize("num_loras", NUM_LORA)
  297. @pytest.mark.parametrize("rank", MAX_RANKS)
  298. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  299. @pytest.mark.parametrize("nslices", [2, 3])
  300. @pytest.mark.parametrize("dtype", DTYPES)
  301. @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
  302. @pytest.mark.parametrize("seed", SEED)
  303. @pytest.mark.parametrize("device", CUDA_DEVICES)
  304. def test_punica_expand_nslices(
  305. batches: int,
  306. num_loras: int,
  307. rank: int,
  308. hidden_size: int,
  309. nslices: int,
  310. dtype: torch.dtype,
  311. op_type: str,
  312. seed: int,
  313. device: str,
  314. ):
  315. from aphrodite.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
  316. random.seed(seed)
  317. torch.set_default_device(device)
  318. torch.random.manual_seed(seed)
  319. if torch.cuda.is_available():
  320. torch.cuda.manual_seed(seed)
  321. seq_length = 128 if op_type == "sgmv" else 1
  322. (
  323. inputs_tensor,
  324. lora_weights_lst,
  325. our_outputs,
  326. ref_outputs,
  327. b_seq_start_loc,
  328. lora_indices_tensor,
  329. seq_len_tensor,
  330. indices,
  331. ) = generate_data_for_expand_nslices(
  332. batches,
  333. hidden_size,
  334. num_loras,
  335. rank,
  336. seq_length,
  337. dtype,
  338. nslices,
  339. device,
  340. )
  341. max_seq_length = seq_len_tensor.max()
  342. token_nums = seq_len_tensor.sum().item()
  343. if isinstance(max_seq_length, tuple):
  344. max_seq_length = max_seq_length[0].item()
  345. else:
  346. max_seq_length = max_seq_length.item()
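    # Each slice writes a disjoint, hidden_size-wide column block of the
    # output; slice_offset tracks where the next block starts.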
    slice_offset = 0
    for index in range(nslices):
        lora_weights = lora_weights_lst[index]
        if op_type == "sgmv":
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                token_nums,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # libentry decoration. The purpose of adding this patch is to test
            # the correctness of libentry.
            with patch(
                    "aphrodite.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                    LibEntry(_bgmv_expand_slice_kernel),
            ):
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )
        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_offset + hidden_size],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )
        slice_offset += hidden_size
    assert_close(our_outputs, ref_outputs)