# test_rotary.py

import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)

def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """Generate random cos/sin tables of shape (2 * seqlen, rotary_dim // 2)."""
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    return cos, sin


def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
    """Return a seqlen offset of the requested kind: the constant 0, a random
    Python int, or a per-batch int32 tensor of random offsets."""
    if seqlen_offsets_type == 0:
        return 0
    elif seqlen_offsets_type is int:
        return torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)


def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
    """Slice (or gather, for per-batch tensor offsets) the cos/sin tables so that
    position i uses row seqlen_offsets + i."""
    if isinstance(seqlen_offsets, torch.Tensor):
        batch_size = seqlen_offsets.shape[0]
        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    return cos_pt, sin_pt
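
# For orientation, a minimal sketch of the rotation that the non-interleaved
# (GPT-NeoX-style) code path is expected to perform: the first rotary_dim
# channels of each head are split into two halves and rotated pairwise by
# cos/sin, while the remaining channels pass through unchanged. This helper is
# illustrative only and is not used by the tests below; the canonical reference
# implementation is flash_attn's apply_rotary_emb_torch.
def _rotary_reference_sketch(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = rearrange(cos, "s d -> s 1 d")
    sin = rearrange(sin, "s d -> s 1 d")
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)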

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("gqa", [False, True])
# @pytest.mark.parametrize("gqa", [False])
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    if not gqa:
        qkv = torch.randn(
            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
        )
    else:
        nheads_k = nheads // 2
        qkv = torch.randn(
            batch_size,
            seqlen,
            nheads + nheads_k * 2,
            headdim,
            dtype=dtype,
            device=device,
            requires_grad=True,
        )
    qkv_pt = qkv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_qkv_(
        qkv,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        num_heads_q=None if not gqa else nheads,
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    if not gqa:
        q_pt, k_pt, v_pt = qkv_pt.unbind(2)
    else:
        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)
    q_pt = apply_rotary_emb_torch(
        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    if not gqa:
        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)
    else:
        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
    x_pt = x.detach().clone().requires_grad_()
    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)
    x_unpad_clone = x_unpad.clone()
    x_unpad = x_unpad.requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out_unpad = apply_rotary_emb(
        x_unpad,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        inplace=inplace,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    out = pad_input(out_unpad, indices, batch_size, seqlen)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x_unpad, x_unpad_clone)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)

def test_compilation_count():
    batch_size = 1
    headdim = 128
    device = "cuda"
    dtype = torch.float16
    torch.manual_seed(42)

    from triton.runtime.jit import JITFunction
    from flash_attn.ops.triton.rotary import rotary_kernel

    compilation_count = 0

    def count_compilations(*args, **kwargs):
        nonlocal compilation_count
        compilation_count += 1

    old_cache_func = JITFunction.cache_hook
    try:
        rotary_kernel.cache.clear()
        JITFunction.cache_hook = count_compilations
        for seqlen in (128, 256):
            for nheads in (4, 32):
                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
                x.requires_grad_()
                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
                out = apply_rotary_emb(x, cos, sin)
                out.backward(torch.randn_like(out))
        # Only two kernels are expected to be compiled:
        # * for the forward pass (conjugate=False)
        # * for the backward pass (conjugate=True)
        assert compilation_count == 2
    finally:
        JITFunction.cache_hook = old_cache_func
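
# These tests require a CUDA device. A typical invocation (assumed to be run
# from the directory containing this file) would be, for example:
#   pytest test_rotary.py -q
# or, to run a single test across all of its parametrizations:
#   pytest test_rotary.py::test_rotary_emb_func -q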