# test_rotary.py

import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
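
# Random rotary tables: angles are drawn for 2 * seqlen positions so that indexing
# with a sequence-length offset of up to seqlen stays in range.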
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    return cos, sin
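
# Sequence-length offsets come in three parametrized flavors: the constant 0,
# a random Python int, or a per-batch int32 tensor with values in [0, seqlen].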
def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
    if seqlen_offsets_type == 0:
        return 0
    elif seqlen_offsets_type is int:
        return torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
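
# Select the cos/sin rows the pure-PyTorch reference should see: a plain slice
# for a scalar offset, or a per-batch gather for tensor offsets.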
def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
    if isinstance(seqlen_offsets, torch.Tensor):
        batch_size = seqlen_offsets.shape[0]
        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    return cos_pt, sin_pt
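
# For orientation, a minimal sketch of the non-interleaved rotation that the fp32
# reference (apply_rotary_emb_torch) is expected to apply to the rotary channels.
# Illustrative only and not used by the tests below; it assumes x holds just the
# rotary slice (batch, seqlen, nheads, rotary_dim) and cos/sin have shape
# (seqlen, rotary_dim // 2). The tests compare the fused Triton path against the
# fp32 reference, with atol bootstrapped from the reference's own round-off error.
def _rotary_reference_sketch(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)  # split the rotary channels into two halves
    cos = rearrange(cos, "s d -> s 1 d")  # broadcast over batch and heads
    sin = rearrange(sin, "s d -> s 1 d")
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)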

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
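
# Same comparison for the fused QKV variant: apply_rotary_emb_qkv_ rotates q and
# k in place and leaves v untouched, so the reference stacks the rotated q/k with
# the original v slice.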
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
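
# Fused KV variant: apply_rotary_emb_kv_ rotates only k in place; v passes through
# unchanged, so the reference stacks the rotated k with the original v slice.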
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
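
# Variable-length path: inputs are unpadded to a packed (total_tokens, ...) layout
# with unpad_input, rotated with cu_seqlens/max_seqlen, then re-padded with
# pad_input before comparing against the dense reference (padded positions zeroed).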
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
    x_pt = x.detach().clone().requires_grad_()
    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
    x_unpad_clone = x_unpad.clone()
    x_unpad = x_unpad.requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out_unpad = apply_rotary_emb(
        x_unpad,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        inplace=inplace,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    out = pad_input(out_unpad, indices, batch_size, seqlen)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x_unpad, x_unpad_clone)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
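
# Compilation-count check: a Triton JITFunction cache hook counts how many times
# rotary_kernel is compiled while sweeping seqlen and nheads; only a forward and
# a backward specialization are expected.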
def test_compilation_count():
    batch_size = 1
    headdim = 128
    device = "cuda"
    dtype = torch.float16
    torch.manual_seed(42)

    from triton.runtime.jit import JITFunction
    from flash_attn.ops.triton.rotary import rotary_kernel

    compilation_count = 0

    def count_compilations(*args, **kwargs):
        nonlocal compilation_count
        compilation_count += 1

    old_cache_func = JITFunction.cache_hook
    try:
        rotary_kernel.cache.clear()
        JITFunction.cache_hook = count_compilations
        for seqlen in (128, 256):
            for nheads in (4, 32):
                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
                x.requires_grad_()
                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
                out = apply_rotary_emb(x, cos, sin)
                out.backward(torch.randn_like(out))
        # Only two kernels are expected to be compiled:
        # * for the forward pass (conjugate=False)
        # * for the backward pass (conjugate=True)
        assert compilation_count == 2
    finally:
        JITFunction.cache_hook = old_cache_func