test_dropout_layer_norm.py

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange, repeat

from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual

try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    FusedRMSNorm, fused_rms_norm_affine = None, None

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8

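# Explanatory note on the test strategy used throughout this file: each case builds three
# parallel computations from identical data -- the fused kernel under test (no suffix),
# a PyTorch baseline run in the same low-precision dtypes (`_pt`), and an fp32 reference
# (`_ref`). Instead of fixed rtol/atol, the asserts require the fused kernel's error
# against the fp32 reference to stay within a small multiple (2-4x) of the baseline's own
# error, plus a small absolute slack. bf16 cases are only generated on compute capability
# 8.x or newer GPUs (Ampere-class and later), gated by `is_sm8x` above.
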
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale, has_colscale,
                                     is_rms_norm):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)
                       + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4

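# Note: in eval mode dropout is a no-op, so DropoutAddLayerNorm should reduce to
# LayerNorm(x0 + residual). The eval tests (here and the prenorm variant further down)
# therefore only check the forward output against the PyTorch and fp32 references,
# with no backward pass.
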
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale, has_colscale,
                                             is_rms_norm):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
                               dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                               model.epsilon, rowscale=rowscale,
                                               layerscale=colscale, prenorm=True,
                                               residual_in_fp32=residual_in_fp32,
                                               return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)
                       + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4

@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

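# The subset tests below exercise dropout_add_layer_norm_subset, which (as used here)
# emulates stochastic depth / DropPath at the batch level: generate_droppath_masks draws a
# per-sample keep mask on the CPU and turns it into 1-based row indices via cumsum (dropped
# rows get index 0). Only the kept rows of x0 are fed to the kernel (x0_subset), only a
# subset of output rows is produced (out_subset / out_numrows), and
# rowscale_const = 1 / (1 - drop_path_rate) rescales the surviving rows.
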
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                                   drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    x0_scaled_pt = x0_scaled_pt.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)
                       + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4

@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                                   drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    x0_scaled_pt = x0_scaled_pt.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)
                       + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4

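# The parallel-residual tests below feed two independent inputs x0 and x1 through dropout,
# add both (plus an optional residual) into a single residual stream, and normalize it
# either with one shared set of norm parameters (tied_norm=True, in which case only out0 is
# checked) or with two separate sets (weight0/bias0 and weight1/bias1). This mirrors
# architectures that apply attention and MLP in parallel off the same residual stream.
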
@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype,
        dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
                           else dropout_add_rms_norm_parallel_residual)
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                            requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
             if not is_rms_norm else None)
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
                 if not is_rms_norm else None)
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
        epsilon, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
    if has_residual:
        if has_x1:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                           + res_pt.float()).to(dtype=residual_dtype)
            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + res_pt.float()).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt,
                               bias0_pt, eps=epsilon).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
                                   bias1_pt, eps=epsilon).to(dtype=input_dtype)
            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
    else:
        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt,
                                        (hidden_size,), eps=epsilon).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        out0.backward(g0)
        out0_pt.backward(g0)
        out0_ref.backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * g0 + out1 * g1).sum().backward()
        (out0_pt * g0 + out1_pt * g1).sum().backward()
        (out0_ref * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5

@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                            if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size',
                         [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072,
                          4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype,
        dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
                           else dropout_add_rms_norm_parallel_residual)
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                            requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
             if not is_rms_norm else None)
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
                 if not is_rms_norm else None)
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
        epsilon, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
    if has_residual:
        if has_x1:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                           + res_pt.float()).to(dtype=residual_dtype)
            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + res_pt.float()).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                           + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt,
                               bias0_pt, eps=epsilon).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
                                   bias1_pt, eps=epsilon).to(dtype=input_dtype)
            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
    else:
        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt,
                                        (hidden_size,), eps=epsilon).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        (out0 * F.sigmoid(residual)).backward(g0)
        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5