import math

import torch
import torch.nn.functional as F
import pytest
from einops import rearrange

from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
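# bfloat16 cases are only exercised on SM 8.x (Ampere or later) GPUs, since bf16
# requires compute capability >= 8.0.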


@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560,
                                         3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale):
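    """Compare the fused dropout + residual-add + LayerNorm against a PyTorch composition.

    Both the fused op and the PyTorch composition are checked against an fp32 reference;
    the fused result and its gradients must stay within a small multiple of the
    composition's error.
    """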
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
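    # residual_in_fp32 is only meaningful when there is no residual input (x1 is None);
    # it asks the kernel to keep the internal residual in fp32 instead of input_dtype.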
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                        model.epsilon, rowscale=rowscale,
                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
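    # The fused output may be at most ~4x further from the fp32 reference than the
    # PyTorch composition is, plus a small absolute slack.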
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
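    """Check DropoutAddLayerNorm in eval mode, where dropout is disabled and the op
    reduces to residual-add + LayerNorm.
    """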
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560,
                                         3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale):
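    """Check the prenorm variant in training mode.

    With prenorm=True the op also returns the pre-LayerNorm residual, so both outputs
    and the gradients flowing through them are compared against the PyTorch composition.
    """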
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                                  model.epsilon, rowscale=rowscale, prenorm=True,
                                                  residual_in_fp32=residual_in_fp32,
                                                  return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
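    # Backprop through a function that mixes both outputs, so gradients flow through
    # out and the prenorm residual simultaneously.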
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
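    """Check the prenorm variant in eval mode: dropout is disabled and both the
    normalized output and the returned residual are compared against the reference.
    """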
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4