test_punica_variation.py

  1. """
  2. This script is mainly used to test whether trtion kernels can run normally
  3. under different conditions, including various batches, numbers of LoRA , and
  4. maximum ranks.
  5. """
from unittest.mock import patch

import pytest
import torch

from aphrodite.common.utils import seed_everything
from aphrodite.lora.ops.bgmv_expand import bgmv_expand
from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
from aphrodite.lora.ops.sgmv_expand import sgmv_expand
from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
from aphrodite.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
                    ref_torch_groupgemm)

HIDDEN_SIZES = [4097]

BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 8, 32, 128]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]


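# Tolerances used to compare the Triton kernel output against the torch
# reference; half-precision dtypes get a looser bound than float32.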
def assert_close(a, b):
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


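# test_punica_sgmv checks sgmv_shrink / sgmv_expand against the torch
# group-GEMM reference. Every request carries a 128-token sequence; "shrink"
# is the LoRA down-projection (hidden_size -> rank, scaled by `scaling`) and
# "expand" is the up-projection (rank -> hidden_size, with add_inputs=True).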
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    torch.set_default_device(device)
    seed_everything(seed)

    seq_length = 128
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
    max_seq_length = seq_len_tensor.max()
    token_nums = seq_len_tensor.sum().item()
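    # seq_len_tensor.max() returns a 0-dim tensor here; the isinstance check
    # below defensively handles torch.max variants that return a
    # (values, indices) pair.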
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()

    if op_type == "shrink":
        sgmv_shrink(
            inputs_tensor,
            lora_weights,
            our_out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            lora_indices_tensor,
            batches,
            max_seq_length,
            token_nums,
            scaling,
        )
    else:
        sgmv_expand(
            inputs_tensor,
            lora_weights,
            our_out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            lora_indices_tensor,
            batches,
            max_seq_length,
            token_nums,
            add_inputs=True,
        )
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if op_type == "shrink" else 1.0,
        op_type,
    )
    if op_type == "shrink":
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)


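# test_punica_bgmv covers the single-token (seq_length == 1) decode path.
# The kernels are wrapped in LibEntry via mock.patch purely to exercise the
# LibEntry wrapper; the unpatched kernels do not need it.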
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    from aphrodite.lora.ops.bgmv_expand import _bgmv_expand_kernel
    from aphrodite.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

    torch.set_default_device(device)
    seed_everything(seed)

    seq_length = 1
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
    if op_type == "shrink":
        # The current _bgmv_shrink_kernel does not require the libentry
        # decoration. The purpose of adding this patch is to test the
        # correctness of libentry.
        with patch(
                "aphrodite.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
                LibEntry(_bgmv_shrink_kernel),
        ):
            bgmv_shrink(
                inputs_tensor,
                lora_weights,
                our_out_tensor,
                indices,
                scaling,
            )
    else:
        # ditto
        with patch(
                "aphrodite.lora.ops.bgmv_expand._bgmv_expand_kernel",
                LibEntry(_bgmv_expand_kernel),
        ):
            bgmv_expand(
                inputs_tensor,
                lora_weights,
                our_out_tensor,
                indices,
                add_inputs=True,
            )
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if op_type == "shrink" else 1.0,
        op_type,
    )
    if op_type == "shrink":
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)


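# test_punica_expand_nslices writes each of `nslices` expand results into its
# own hidden_size-wide column block of a shared output tensor (slice_offset
# advances by hidden_size per slice), e.g. as in merged QKV-style LoRA layers.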
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    from aphrodite.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

    torch.set_default_device(device)
    seed_everything(seed)

    seq_length = 128 if op_type == "sgmv" else 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )
    max_seq_length = seq_len_tensor.max()
    token_nums = seq_len_tensor.sum().item()
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()

    slice_offset = 0
    for index in range(nslices):
        lora_weights = lora_weights_lst[index]
        if op_type == "sgmv":
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                token_nums,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # libentry decoration. The purpose of adding this patch is to test
            # the correctness of libentry.
            with patch(
                    "aphrodite.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                    LibEntry(_bgmv_expand_slice_kernel),
            ):
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )
        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_offset + hidden_size],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )
        slice_offset += hidden_size
    assert_close(our_outputs, ref_outputs)
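
# Example invocation (assumes a CUDA device is available):
#   pytest test_punica_variation.py -k shrink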