test_punica_variation.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. """
  2. This script mainly tests whether the Triton kernels run correctly
  3. under different conditions, including various batch sizes, numbers of
  4. LoRAs, and maximum ranks.
  5. """
  6. import random
  7. from unittest.mock import patch
  8. import pytest
  9. import torch
  10. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  11. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  12. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  13. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  14. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  15. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  16. from aphrodite.triton_utils.libentry import LibEntry
  17. from .utils import (generate_data, generate_data_for_expand_nslices,
  18. ref_torch_groupgemm)
  19. HIDDEN_SIZES = [4097]
  20. BATCHES = [1, 4, 16, 32]
  21. NUM_LORA = [1, 8, 32, 128]
  22. DTYPES = [torch.float16, torch.bfloat16]
  23. MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
  24. SCALES = [0.5]
  25. SEED = [0]
  26. CUDA_DEVICES = [f"cuda:{0}"]
  27. def assert_close(a, b):
  28. rtol, atol = {
  29. torch.float16: (6e-2, 6e-2),
  30. torch.bfloat16: (6e-2, 6e-2),
  31. torch.float32: (1e-2, 1e-2),
  32. }[a.dtype]
  33. torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
  34. @pytest.mark.parametrize("batches", BATCHES)
  35. @pytest.mark.parametrize("num_loras", NUM_LORA)
  36. @pytest.mark.parametrize("rank", MAX_RANKS)
  37. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  38. @pytest.mark.parametrize("scaling", SCALES)
  39. @pytest.mark.parametrize("dtype", DTYPES)
  40. @pytest.mark.parametrize("op_type", ["shrink", "expand"])
  41. @pytest.mark.parametrize("seed", SEED)
  42. @pytest.mark.parametrize("device", CUDA_DEVICES)
  43. def test_punica_sgmv(
  44. batches: int,
  45. num_loras: int,
  46. rank: int,
  47. hidden_size: int,
  48. scaling: float,
  49. dtype: torch.dtype,
  50. op_type: str,
  51. seed: int,
  52. device: str,
  53. ):
  54. random.seed(seed)
  55. torch.set_default_device(device)
  56. torch.random.manual_seed(seed)
  57. if torch.cuda.is_available():
  58. torch.cuda.manual_seed(seed)
  59. seq_length = 128
  60. (
  61. inputs_tensor,
  62. lora_weights,
  63. our_out_tensor,
  64. ref_out_tensor,
  65. b_seq_start_loc,
  66. lora_indices_tensor,
  67. seq_len_tensor,
  68. indices,
  69. ) = generate_data(
  70. batches,
  71. hidden_size,
  72. num_loras,
  73. rank,
  74. seq_length,
  75. dtype,
  76. op_type,
  77. device,
  78. )
  79. max_seq_length = seq_len_tensor.max()
  80. if isinstance(max_seq_length, tuple):
  81. max_seq_length = max_seq_length[0].item()
  82. else:
  83. max_seq_length = max_seq_length.item()
  84. if op_type == "shrink":
  85. sgmv_shrink(
  86. inputs_tensor,
  87. lora_weights,
  88. our_out_tensor,
  89. b_seq_start_loc,
  90. seq_len_tensor,
  91. lora_indices_tensor,
  92. batches,
  93. max_seq_length,
  94. scaling,
  95. )
  96. else:
  97. sgmv_expand(
  98. inputs_tensor,
  99. lora_weights,
  100. our_out_tensor,
  101. b_seq_start_loc,
  102. seq_len_tensor,
  103. lora_indices_tensor,
  104. batches,
  105. max_seq_length,
  106. add_inputs=True,
  107. )
  108. ref_torch_groupgemm(
  109. ref_out_tensor,
  110. inputs_tensor,
  111. lora_weights,
  112. lora_indices_tensor,
  113. seq_len_tensor,
  114. batches,
  115. scaling if op_type == "shrink" else 1.0,
  116. op_type,
  117. )
  118. if op_type == "shrink":
  119. ref_out_tensor = ref_out_tensor.to(torch.float32)
  120. assert_close(our_out_tensor, ref_out_tensor)
  121. @pytest.mark.parametrize("batches", BATCHES)
  122. @pytest.mark.parametrize("num_loras", NUM_LORA)
  123. @pytest.mark.parametrize("rank", MAX_RANKS)
  124. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  125. @pytest.mark.parametrize("scaling", SCALES)
  126. @pytest.mark.parametrize("dtype", DTYPES)
  127. @pytest.mark.parametrize("op_type", ["shrink", "expand"])
  128. @pytest.mark.parametrize("seed", SEED)
  129. @pytest.mark.parametrize("device", CUDA_DEVICES)
  130. def test_punica_bgmv(
  131. batches: int,
  132. num_loras: int,
  133. rank: int,
  134. hidden_size: int,
  135. scaling: float,
  136. dtype: torch.dtype,
  137. op_type: str,
  138. seed: int,
  139. device: str,
  140. ):
  141. from aphrodite.lora.ops.bgmv_expand import _bgmv_expand_kernel
  142. from aphrodite.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
  143. random.seed(seed)
  144. torch.set_default_device(device)
  145. torch.random.manual_seed(seed)
  146. if torch.cuda.is_available():
  147. torch.cuda.manual_seed(seed)
  148. seq_length = 1
  149. (
  150. inputs_tensor,
  151. lora_weights,
  152. our_out_tensor,
  153. ref_out_tensor,
  154. b_seq_start_loc,
  155. lora_indices_tensor,
  156. seq_len_tensor,
  157. indices,
  158. ) = generate_data(
  159. batches,
  160. hidden_size,
  161. num_loras,
  162. rank,
  163. seq_length,
  164. dtype,
  165. op_type,
  166. device,
  167. )
  168. if op_type == "shrink":
  169. # The current _bgmv_shrink_kernel does not require the libentry
  170. # decoration. The purpose of adding this patch is to test the
  171. # correctness of libentry.
  172. with patch(
  173. "aphrodite.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
  174. LibEntry(_bgmv_shrink_kernel),
  175. ):
  176. bgmv_shrink(
  177. inputs_tensor,
  178. lora_weights,
  179. our_out_tensor,
  180. indices,
  181. scaling,
  182. )
  183. else:
  184. # ditto
  185. with patch(
  186. "aphrodite.lora.ops.bgmv_expand._bgmv_expand_kernel",
  187. LibEntry(_bgmv_expand_kernel),
  188. ):
  189. bgmv_expand(
  190. inputs_tensor,
  191. lora_weights,
  192. our_out_tensor,
  193. indices,
  194. add_inputs=True,
  195. )
  196. ref_torch_groupgemm(
  197. ref_out_tensor,
  198. inputs_tensor,
  199. lora_weights,
  200. lora_indices_tensor,
  201. seq_len_tensor,
  202. batches,
  203. scaling if op_type == "shrink" else 1.0,
  204. op_type,
  205. )
  206. if op_type == "shrink":
  207. ref_out_tensor = ref_out_tensor.to(torch.float32)
  208. assert_close(our_out_tensor, ref_out_tensor)
  209. @pytest.mark.parametrize("batches", BATCHES)
  210. @pytest.mark.parametrize("num_loras", NUM_LORA)
  211. @pytest.mark.parametrize("rank", MAX_RANKS)
  212. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  213. @pytest.mark.parametrize("nslices", [2, 3])
  214. @pytest.mark.parametrize("dtype", DTYPES)
  215. @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
  216. @pytest.mark.parametrize("seed", SEED)
  217. @pytest.mark.parametrize("device", CUDA_DEVICES)
  218. def test_punica_expand_nslices(
  219. batches: int,
  220. num_loras: int,
  221. rank: int,
  222. hidden_size: int,
  223. nslices: int,
  224. dtype: torch.dtype,
  225. op_type: str,
  226. seed: int,
  227. device: str,
  228. ):
  229. from aphrodite.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
  230. random.seed(seed)
  231. torch.set_default_device(device)
  232. torch.random.manual_seed(seed)
  233. if torch.cuda.is_available():
  234. torch.cuda.manual_seed(seed)
  235. seq_length = 128 if op_type == "sgmv" else 1
  236. (
  237. inputs_tensor,
  238. lora_weights_lst,
  239. our_outputs,
  240. ref_outputs,
  241. b_seq_start_loc,
  242. lora_indices_tensor,
  243. seq_len_tensor,
  244. indices,
  245. ) = generate_data_for_expand_nslices(
  246. batches,
  247. hidden_size,
  248. num_loras,
  249. rank,
  250. seq_length,
  251. dtype,
  252. nslices,
  253. device,
  254. )
  255. max_seq_length = seq_len_tensor.max()
  256. if isinstance(max_seq_length, tuple):
  257. max_seq_length = max_seq_length[0].item()
  258. else:
  259. max_seq_length = max_seq_length.item()
  260. slice_offset = 0
  261. for index in range(nslices):
  262. lora_weights = lora_weights_lst[index]
  263. if op_type == "sgmv":
  264. sgmv_expand_slice(
  265. inputs_tensor,
  266. lora_weights,
  267. our_outputs,
  268. b_seq_start_loc,
  269. seq_len_tensor,
  270. lora_indices_tensor,
  271. batches,
  272. max_seq_length,
  273. slice_offset,
  274. hidden_size,
  275. add_inputs=True,
  276. )
  277. else:
  278. # The current _bgmv_expand_slice_kernel does not require the
  279. # libentry decoration. The purpose of adding this patch is to test
  280. # the correctness of libentry.
  281. with patch(
  282. "aphrodite.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
  283. LibEntry(_bgmv_expand_slice_kernel),
  284. ):
  285. bgmv_expand_slice(
  286. inputs_tensor,
  287. lora_weights,
  288. our_outputs,
  289. indices,
  290. slice_offset,
  291. slice_size=hidden_size,
  292. add_inputs=True,
  293. )
  294. ref_torch_groupgemm(
  295. ref_outputs[:, slice_offset:slice_offset + hidden_size],
  296. inputs_tensor,
  297. lora_weights,
  298. lora_indices_tensor,
  299. seq_len_tensor,
  300. batches,
  301. 1.0,
  302. op_type="expand",
  303. )
  304. slice_offset += hidden_size
  305. assert_close(our_outputs, ref_outputs)