# test_dropout_layer_norm.py

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNorm,
    dropout_add_layer_norm,
    dropout_add_layer_norm_parallel_residual,
    dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
    DropoutAddRMSNorm,
    dropout_add_rms_norm,
    dropout_add_rms_norm_parallel_residual,
    dropout_add_rms_norm_subset,
)

try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    FusedRMSNorm, fused_rms_norm_affine = None, None


is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
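

# The tests below check the fused kernels against unfused references built from
# plain PyTorch ops. Purely as an illustration (not used by any test), a minimal
# sketch of the composition that dropout_add_layer_norm is expected to fuse,
# assuming rowscale/colscale are applied to x0 before dropout and the residual
# is added afterwards, matching the reference computations inside the tests:
def _reference_dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p, eps, rowscale=None, colscale=None
):
    if rowscale is not None:
        x0 = x0 * rearrange(rowscale, "... -> ... 1")
    if colscale is not None:
        x0 = x0 * colscale
    dropped = F.dropout(x0, p=dropout_p)
    pre_norm = dropped if residual is None else dropped + residual
    # Normalize in the weight dtype, mirroring how the tests run the PyTorch baseline
    return F.layer_norm(pre_norm.to(weight.dtype), (x0.shape[-1],), weight, bias, eps)
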
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
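

# The assertions above all follow the same pattern: the fused kernel's error,
# measured against an fp32 reference, must stay within a small multiple of the
# error of the unfused PyTorch baseline, plus a small absolute slack. A
# hypothetical helper (not used by the tests) that spells out this tolerance
# scheme:
def _assert_close_to_baseline(out, out_pt, out_ref, mult=4.0, slack=1e-4):
    fused_err = (out - out_ref).abs().max()
    baseline_err = (out_pt - out_ref).abs().max()
    assert fused_err <= mult * baseline_err + slack
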
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=False,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
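

# A tiny worked example of the subset indexing built by generate_droppath_masks
# above, assuming the convention that kept rows get 1-based positions in the
# packed tensor and dropped rows get index 0 (illustration only, not used by
# the tests):
def _example_subset_indices():
    mask_batch = torch.tensor([True, False, True])  # batch_size = 3, keep rows 0 and 2
    mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=2)  # seqlen = 2
    subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
        ~mask_batch_seqlen, 0
    )
    # subset is now [1, 2, 0, 0, 3, 4]: batch elements 0 and 2 occupy rows 1-4 of
    # the packed tensor, while batch element 1 contributes nothing.
    return rearrange(subset, "(b s) -> b s", b=3)
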
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
        g
    )
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (
        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
        + residual_ref.mean(0, keepdim=True)
    ).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        out0.backward(g0)
        out0_pt.backward(g0)
        out0_ref.backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * g0 + out1 * g1).sum().backward()
        (out0_pt * g0 + out1_pt * g1).sum().backward()
        (out0_ref * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5
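

# For the parallel-residual variants, the reference computation is: both input
# branches are dropped out, summed together with the residual, and the same sum
# is fed to one or two norms. A minimal sketch of that composition in plain
# PyTorch (illustration only, not used by the tests; dtype handling omitted):
def _reference_parallel_residual(x0, x1, res, weight0, bias0, weight1, bias1, dropout_p, eps):
    hidden_size = x0.shape[-1]
    residual = F.dropout(x0, p=dropout_p)
    if x1 is not None:
        residual = residual + F.dropout(x1, p=dropout_p)
    if res is not None:
        residual = residual + res
    out0 = F.layer_norm(residual, (hidden_size,), weight0, bias0, eps)
    out1 = (
        F.layer_norm(residual, (hidden_size,), weight1, bias1, eps)
        if weight1 is not None
        else None
    )
    return out0, out1, residual
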
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        (out0 * F.sigmoid(residual)).backward(g0)
        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5
def test_dropout_layer_norm_randomness():
    hidden_size = 256
    dtype = torch.float32
    dropout_p = 0.1
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
    torch.random.manual_seed(42)
    _, dmask0 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    # A subsequent call should produce a different dropout mask
    _, dmask1 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    torch.random.manual_seed(42)
    # Resetting the seed should reproduce the same dropout mask
    _, dmask2 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    assert not torch.equal(dmask0, dmask1)
    assert torch.equal(dmask0, dmask2)