# Copyright (c) 2023, Tri Dao.

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
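
# The tests below check flash_attn's RotaryEmbedding (which rotates qkv in
# place) against the Hugging Face reference implementations from GPT-NeoX and
# GPT-J. A rotary embedding rotates each paired feature (x1, x2) at position m
# by the angle m * theta_i:
#
#     x1' = x1 * cos(m * theta_i) - x2 * sin(m * theta_i)
#     x2' = x1 * sin(m * theta_i) + x2 * cos(m * theta_i)
#
# where theta_i = base ** (-2 * i / rotary_dim) (base is 10000 by default).
# With rotary_emb_fraction < 1.0 only the first rotary_dim features of each
# head are rotated; the rest must pass through unchanged, which the
# torch.equal checks below verify. The NeoX and GPT-J styles differ only in
# how features are paired.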


# NeoX-style rotary embedding
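# NeoX pairs feature i with feature i + rotary_dim // 2 ("rotate half"). As a
# point of reference, a minimal pure-PyTorch sketch of that rotation; this
# helper is illustrative only, not part of flash_attn's or transformers' API.
def rotate_half_ref(x, cos, sin):
    # x: (..., d) with d even; cos/sin broadcast against each half of size d // 2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
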
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = (
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    k_pt = (
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
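    # A nonzero seqlen_offset shifts all position indices by seqlen_offset,
    # mimicking decoding with a KV cache where the new tokens start at that position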
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(
        rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
    assert torch.allclose(
        rearrange(q_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 0, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        rearrange(k_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 1, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
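# GPT-J pairs adjacent features (2i, 2i + 1), i.e. the "interleaved" layout
# (hence interleaved=True below). As a point of reference, a minimal
# pure-PyTorch sketch of that rotation; this helper is illustrative only, not
# part of flash_attn's or transformers' API.
def rotate_interleaved_ref(x, cos, sin):
    # x: (..., d) with d even; cos/sin broadcast against the d // 2 feature pairs
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Rotate each (even, odd) pair, then re-interleave along the last dim
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
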
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
    sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
    q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
    k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
    q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
    k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
    assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
    k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
    assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])