# Copyright (c) 2024, Tri Dao.

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.ops.triton.layer_norm import (
    layer_norm_fn,
    layer_norm_ref,
    rms_norm_ref,
    layer_norm_linear_fn,
)

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
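# bfloat16 configurations below are only added on SM 8.x (Ampere or newer) GPUs, where bf16 is supported.

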
@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    has_residual,
    is_rms_norm,
    prenorm,
    dropout_p,
    has_rowscale,
    has_x1,
    has_weight1,
):
    if has_rowscale and has_x1:
        pytest.skip("Not supported")
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
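    # allclose(x, x_pt, x_ref): x is the Triton output, x_pt the reference run in the native
    # dtype, and x_ref the reference run upcast to fp32. The Triton output may deviate from the
    # fp32 reference by at most twice the PyTorch implementation's deviation, plus atol.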
    allclose = (
        # Sometimes x0_pt.grad is NaN
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
        or (
            # Sometimes x_pt and x_ref are identical (e.g. bfloat16), so we perturb x_pt a bit
            # by multiplying and dividing by 0.3
            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
            and (x - x_ref).abs().max()
            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
        )
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        bias = None
    weight_pt = weight.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    if has_x1:
        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
        x1_pt = x1.detach().clone().requires_grad_()
        x1_ref = x1.detach().clone().requires_grad_()
    else:
        x1, x1_pt, x1_ref = None, None, None
    if has_weight1:
        weight1 = torch.randn(
            hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().requires_grad_()
        if not is_rms_norm:
            bias1 = torch.randn(
                hidden_size, device=device, dtype=weight_dtype, requires_grad=True
            )
        else:
            bias1 = None
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
    else:
        weight1, weight1_pt, weight1_ref = None, None, None
        bias1, bias1_pt, bias1_ref = None, None, None
    rowscale = (
        torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
        if has_rowscale
        else None
    )
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
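    # Fused Triton kernel. As unpacked below, `rest` holds (in order): out1 when weight1 is
    # given, the updated residual when prenorm=True, and finally the dropout mask(s) since
    # return_dropout_mask=True.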
    out, *rest = layer_norm_fn(
        x0,
        weight,
        bias,
        residual=res,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
    dropout_mask = rest[-2] if dropout_p > 0.0 else None
    dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
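    # The reference implementations reuse the dropout masks produced by the fused kernel, so
    # their outputs are directly comparable to the Triton outputs.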
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        x1=x1_pt,
        weight1=weight1_pt,
        bias1=bias1_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        x1=x1_ref,
        weight1=weight1_ref,
        bias1=bias1_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
        upcast=True,
    )
    if not has_weight1:
        if prenorm:
            residual = rest[0]
            out_pt, residual_pt = out_pt
            out_ref, residual_ref = out_ref
        out1, out1_pt, out1_ref = None, None, None
    else:
        out1 = rest.pop(0)
        if prenorm:
            residual = rest[0]
            out_pt, out1_pt, residual_pt = out_pt
            out_ref, out1_ref, residual_ref = out_ref
        else:
            out_pt, out1_pt = out_pt
            out_ref, out1_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if out1 is not None:
        assert out1.dtype == input_dtype
        assert allclose(out1, out1_pt, out1_ref)
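    # With dropout enabled, the fraction of dropped entries should match dropout_p to within
    # 0.01, and the masks for x0 and x1 should not be identical.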
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
    if dropout_mask1 is not None:
        dropout_fraction = 1.0 - dropout_mask1.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
        assert not torch.equal(dropout_mask, dropout_mask1)
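    # Backward: the upstream gradient is scaled by 1 / batch_size. When weight1 is present the
    # two outputs are combined with a GELU gate; when prenorm=True the residual also receives a
    # gradient through a sigmoid gate.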
    g = torch.randn_like(out) / batch_size
    if has_weight1:
        out = out * F.gelu(out1)
        out_pt = out_pt * F.gelu(out1_pt)
        out_ref = out_ref * F.gelu(out1_ref)
    if not prenorm:
        out.backward(g)
        out_pt.backward(g)
        out_ref.backward(g)
    else:
        (out * F.sigmoid(residual)).backward(g)
        (out_pt * F.sigmoid(residual_pt)).backward(g)
        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    if has_x1:
        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
    if has_weight1:
        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
    if bias1 is not None:
        assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)


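# The second test compares the fused layer norm + linear path (layer_norm_linear_fn, run under
# torch.autocast) against the reference norm followed by F.linear.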
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm_linear(
    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
):
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt - x_ref).abs().max() + atol
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        norm_bias = None
    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    linear_weight = torch.empty(
        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
    )
    torch.nn.init.xavier_uniform_(linear_weight)
    if not is_rms_norm:
        linear_bias = torch.randn(
            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
    else:
        linear_bias = None
    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
    linear_bias_pt = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )
    linear_bias_ref = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
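    # Run the fused norm + linear kernel under autocast so its matmul executes in input_dtype,
    # matching the reference path where F.linear is also run under autocast.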
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out, *rest = layer_norm_linear_fn(
            x0,
            norm_weight,
            norm_bias,
            linear_weight,
            linear_bias,
            residual=res,
            eps=1e-6,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=is_rms_norm,
        )
    out_pt, *rest_pt = layer_norm_ref_fn(
        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
    )
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
    out_ref, *rest_ref = layer_norm_ref_fn(
        x0_ref,
        norm_weight_ref,
        norm_bias_ref,
        residual=res_ref,
        eps=1e-6,
        prenorm=prenorm,
        upcast=True,
    )
    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
    if prenorm:
        residual = rest[0]
        residual_pt = rest_pt[0]
        residual_ref = rest_ref[0]
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    g = torch.randn_like(out) / batch_size
    out.backward(g)
    out_pt.backward(g)
    out_ref.backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
    if norm_bias is not None:
        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
    if linear_bias is not None:
        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)