import math
import random
import time

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from aphrodite.attention.backends.xformers import _make_alibi_bias
from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
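
# The parametrizations below sweep MHA (num_queries_per_kv == 1), GQA (8) and
# MQA (64 == NUM_HEADS), several head sizes, and both the unquantized ("auto")
# and fp8-quantized KV-cache dtypes.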
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
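
    # query_lens counts the new tokens being prefilled per request, ctx_lens
    # counts the context tokens that are placed in the paged KV cache below,
    # and seq_lens is their sum (the full sequence each query attends over).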
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
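    # b_start_loc gives each request's offset into the packed query/k/v/output
    # tensors (new tokens only); b_seq_start_loc gives its offset into the
    # packed key/value tensors, which also contain the context tokens.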
    for i in range(BS):
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
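        # Scatter this request's context tokens into the paged cache one
        # block at a time: block_table[i, block_id] names the physical block,
        # so the copied tokens occupy slots [block * block_size,
        # block * block_size + copied_len) of the flattened cache view.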
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
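    # The 8-element split along head_size presumably matches the kernel's
    # vectorized cache accesses (8 fp16 values = 16 bytes per load); the
    # reshaped tensors are the paged-cache layout context_attention_fwd reads.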

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    context_attention_fwd(query,
                          k,
                          v,
                          output,
                          kv_cache_dtype,
                          k_cache,
                          v_cache,
                          block_table,
                          b_start_loc,
                          b_seq_len,
                          b_ctx_len,
                          max_input_len,
                          sliding_window=sliding_window)
    torch.cuda.synchronize()
    start_time = time.time()
    context_attention_fwd(query,
                          k,
                          v,
                          output,
                          kv_cache_dtype,
                          k_cache,
                          v_cache,
                          block_table,
                          b_start_loc,
                          b_seq_len,
                          b_ctx_len,
                          max_input_len,
                          sliding_window=sliding_window)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
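    # The timing printout is informational only; correctness is verified
    # against the xformers reference at the end of the test.
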
    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: aphrodite/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])

    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)
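
    # Reference mask: block-diagonal and causal, aligned to the bottom right
    # of each block, so every new query token attends to the full context plus
    # the causal prefix of the new tokens; the sliding-window variant below
    # further limits how far back each token may look.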
    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(
            sliding_window)
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
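    # fp8 KV caches introduce quantization error, so the comparison tolerance
    # is relaxed for those configurations.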
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: aphrodite/aphrodite/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)
        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2,
                                      total_num_heads - closest_power_of_2)
            extra_powers = torch.arange(start=1,
                                        end=1 + 2 * num_remaining_heads,
                                        step=2,
                                        dtype=torch.int32)
            slopes = torch.cat(
                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes
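
    # _get_alibi_slopes reproduces the BLOOM-style ALiBi slopes: head i gets a
    # slope of roughly 2**(-8 * i / total_num_heads), with interpolated extra
    # slopes when the head count is not a power of two.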
    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv
    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)
    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
    for i in range(BS):
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    context_attention_fwd(query,
                          k,
                          v,
                          output,
                          kv_cache_dtype,
                          k_cache,
                          v_cache,
                          block_table,
                          b_start_loc,
                          b_seq_len,
                          b_ctx_len,
                          max_input_len,
                          alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    start_time = time.time()
    context_attention_fwd(query,
                          k,
                          v,
                          output,
                          kv_cache_dtype,
                          k_cache,
                          v_cache,
                          block_table,
                          b_start_loc,
                          b_seq_len,
                          b_ctx_len,
                          max_input_len,
                          alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    if query.shape[0] != key.shape[0]:
        query_pad = torch.empty(sum(seq_lens),
                                num_heads,
                                head_size,
                                dtype=dtype)
        query_pad.uniform_(-1e-3, 1e-3)
        seq_start = 0
        query_start = 0
        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
            seq_end = seq_start + seq_len
            query_end = query_start + query_len
            query_pad[seq_start:seq_end, ...] = torch.cat([
                torch.zeros(
                    seq_len - query_len, num_heads, head_size, dtype=dtype),
                query[query_start:query_end, ...]
            ],
                                                          dim=0)
            seq_start += seq_len
            query_start += query_len
        query = query_pad
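
    # Only the last query_len rows of each padded segment are compared later
    # (out[seq_len - query_len:]), so the zero padding never leaks into the
    # checked reference output.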
    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: aphrodite/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])

    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)
    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
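    # _make_alibi_bias builds one bias per prompt from the full seq_lens
    # (hence the padding above); each prompt's bias is applied as attn_bias[i]
    # in the loop below.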
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    start_time = time.time()
    # Attention with alibi slopes.
    # FIXME: Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: aphrodite/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(query[:,
                                                            seq_start:seq_end],
                                                      key[:,
                                                          seq_start:seq_end],
                                                      value[:,
                                                            seq_start:seq_end],
                                                      attn_bias=attn_bias[i],
                                                      p=0.0,
                                                      scale=scale)
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size)
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
                                                         ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
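

# Convenience entry point for running this file directly during development;
# assumes pytest is available (the tests are normally collected by the
# project's test suite).
if __name__ == "__main__":
    pytest.main([__file__])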