test_rotary.py

# Copyright (c) 2023, Tri Dao.

import math

import torch
import torch.nn.functional as F
import pytest
from einops import rearrange

from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_neox
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj

from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_qkv_
from flash_attn.layers.rotary import RotaryEmbedding
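
# Reference sketch (not part of the tested API): how a rotary cos/sin cache is
# typically built. The base of 10000 and the even-index frequency spacing are
# assumptions from the standard RoPE formulation, not values read out of
# flash_attn.layers.rotary.
def _reference_rotary_cache(rotary_dim, seqlen, device, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32)
                               / rotary_dim))
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.einsum('s,d->sd', t, inv_freq)  # (seqlen, rotary_dim // 2)
    return freqs.cos(), freqs.sin()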

# NeoX-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = 'cuda'
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = rearrange(qkv[:, :, 0, :, :rotary_dim],
                     'b s h d -> b h s d').detach().clone().requires_grad_(True)
    k_pt = rearrange(qkv[:, :, 1, :, :rotary_dim],
                     'b s h d -> b h s d').detach().clone().requires_grad_(True)
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, cos_neox[..., :rotary_dim // 2].to(dtype=dtype),
                          rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sin_neox[..., :rotary_dim // 2].to(dtype=dtype),
                          rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(q_neox, 'b h s d -> b s h d'), out[:, :, 0, :, :rotary_dim],
                          rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(k_neox, 'b h s d -> b s h d'), out[:, :, 1, :, :rotary_dim],
                          rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], 'b s h d -> b h s d'))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], 'b s h d -> b h s d'))
    assert torch.allclose(rearrange(q_pt.grad, 'b h s d -> b s h d'),
                          qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(k_pt.grad, 'b h s d -> b s h d'),
                          qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
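
# Reference sketch of the non-interleaved (NeoX-style) rotation verified above:
# the rotary slice is split into two contiguous halves and each position is
# rotated by its angle. A minimal restatement of the math, not the flash_attn
# kernel; cos/sin are assumed shaped (seqlen, rotary_dim // 2) as in the cache
# sketch above, and seqlen_offset handling is omitted for brevity.
def _reference_apply_rotary(x, cos, sin):
    # x: (batch, seqlen, nheads, rotary_dim)
    x1, x2 = x.chunk(2, dim=-1)
    cos, sin = rearrange(cos, 's d -> s 1 d'), rearrange(sin, 's d -> s 1 d')
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)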

# GPT-J-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = 'cuda'
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
    sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
    q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
    k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
    q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
    k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
    assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
    k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
    assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
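
# Reference sketch of the interleaved (GPT-J-style) rotation verified above:
# instead of two contiguous halves, adjacent channel pairs (d0, d1), (d2, d3),
# ... are rotated together. Same assumptions as the non-interleaved sketch;
# not the flash_attn kernel.
def _reference_apply_rotary_interleaved(x, cos, sin):
    # x: (batch, seqlen, nheads, rotary_dim)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = rearrange(cos, 's d -> s 1 d'), rearrange(sin, 's d -> s 1 d')
    out1, out2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
    # Re-interleave the rotated pairs back into (..., rotary_dim)
    return rearrange(torch.stack([out1, out2], dim=-1), '... d two -> ... (d two)')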