# test_fused_dense.py

import math
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.ops.fused_dense import FusedDense, FusedMLP
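
# These tests compare the fused CUDA ops against plain PyTorch references:
# FusedDense is checked against torch.nn.Linear, and FusedMLP against a
# Linear -> activation -> Linear stack, in fp16/bf16 on a CUDA device.
# Both the forward output and the input/weight/bias gradients are compared,
# with looser tolerances for the weight/bias gradients.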

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = "cuda"
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(
        in_features,
        out_features,
        bias=has_bias,
        return_residual=return_residual,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        if has_bias:
            model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        x_copy = (
            x_copy[..., :out_features]
            if out_features < in_features
            else F.pad(x_copy, (0, out_features - in_features))
        )
        x_pt_copy = (
            x_pt[..., :out_features]
            if out_features < in_features
            else F.pad(x_pt, (0, out_features - in_features))
        )
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt_copy)
        out = out + F.gelu(x_copy)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    if has_bias:
        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
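

# Hedged usage sketch (a hypothetical helper, not part of the test suite and not
# collected by pytest): it restates what the test above exercises, namely that
# FusedDense acts as a drop-in for torch.nn.Linear and, with return_residual=True,
# returns the output together with its input so a residual add can be fused later.
# Calling it requires a CUDA device and a fp16/bf16 dtype, as in the tests.
def _example_fused_dense_usage():
    layer = FusedDense(
        1024, 4096, bias=True, return_residual=True, device="cuda", dtype=torch.float16
    )
    inp = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
    out, residual_inp = layer(inp)  # residual_inp mirrors the input, like x_copy above
    return out, residual_inp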

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("heuristic", ["auto", -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize("return_residual", [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
@pytest.mark.parametrize("has_bias1", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(
    in_features,
    out_features,
    activation,
    has_bias1,
    has_bias2,
    return_residual,
    checkpoint_lvl,
    heuristic,
    dtype,
):
    device = "cuda"
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(
        in_features, out_features, bias=has_bias1, device=device, dtype=dtype
    )
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    model = FusedMLP(
        in_features,
        out_features,
        in_features,
        activation=activation,
        bias1=has_bias1,
        bias2=has_bias2,
        return_residual=return_residual,
        checkpoint_lvl=checkpoint_lvl,
        heuristic=heuristic,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        if has_bias1:
            model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    activation_fn = (
        partial(F.gelu, approximate="tanh")
        if activation == "gelu_approx"
        else partial(F.relu, inplace=True)
    )
    out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt)
        out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    # The error for relu is higher still
    if activation == "relu":
        atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias1:
        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(
        model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias2:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
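

# Hedged usage sketch (a hypothetical helper, not collected by pytest): FusedMLP is
# exercised above as a Linear -> activation -> Linear block equivalent to the
# model_pt_fc1 / model_pt_fc2 reference. The positional args follow the test's call:
# input width, hidden width, output width. checkpoint_lvl and heuristic are simply
# forwarded to the fused op; see flash_attn.ops.fused_dense for their exact meaning.
def _example_fused_mlp_usage():
    mlp = FusedMLP(
        1024,  # input width
        4096,  # hidden width (fc1 out_features in the reference above)
        1024,  # output width (fc2 projects back to the input width)
        activation="gelu_approx",
        bias1=True,
        bias2=True,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device="cuda",
        dtype=torch.bfloat16,
    )
    inp = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
    return mlp(inp)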