fused_dense.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688
  1. # Copyright (c) 2023, Tri Dao.
  2. # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
  3. # We make it work with pytorch amp and with bfloat16.
  4. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
  5. from functools import partial
  6. from typing import Optional
  7. # import fused_dense_cuda # from apex
  8. import fused_dense_lib as fused_dense_cuda
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from torch import Tensor
  13. from torch.cuda.amp import custom_bwd, custom_fwd
  14. from torch.distributed import ProcessGroup
  15. from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
  16. from flash_attn.utils.distributed import (
  17. all_gather_raw,
  18. all_reduce,
  19. all_reduce_raw,
  20. reduce_scatter,
  21. reduce_scatter_raw,
  22. )
class FusedDenseFunc(torch.autograd.Function):
    """Autograd function computing ``F.linear`` with optional tensor/sequence parallelism.

    Forward computes ``F.linear(total_x, weight, bias)``, where ``total_x`` is ``x`` itself
    or, when ``process_group`` is set and ``sequence_parallel`` is True, the all-gather of
    ``x`` over ``process_group``. Backward computes the weight/bias gradients with
    ``fused_dense_cuda.linear_bias_wgrad`` and reduces the input gradient over
    ``process_group`` (``reduce_scatter_raw`` with sequence parallelism, ``all_reduce_raw``
    otherwise).
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.

        Args:
            ctx: autograd context.
            x: input tensor; its last dimension is contracted with ``weight``.
            weight: weight matrix passed to ``F.linear``.
            bias: optional bias passed to ``F.linear``.
            return_residual: if True, also return ``x`` (after autocast casting and
                ``.contiguous()``).
            process_group: optional process group used for tensor parallelism.
            sequence_parallel: if True (and ``process_group`` is set), all-gather ``x``
                over ``process_group`` before the matmul.

        Returns:
            ``output``, or ``(output, x)`` when ``return_residual`` is True.

        Raises:
            RuntimeError: if even the smallest GEMM dimension exceeds ``65535 * 32``.
        """
        # x only needs to be kept for backward when a weight gradient will be computed.
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            # The gathered input is needed from here on.
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): this compares the *smallest* dimension against the limit, so it only
        # raises when every dimension exceeds 2M, while the message reads like a per-dim
        # bound — confirm which is intended.
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        output = F.linear(total_x, weight, bias)
        # Save the local (un-gathered) x; backward re-gathers it when needed.
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Return gradients for (x, weight, bias) followed by ``None`` for the three flags."""
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            # With return_residual, args holds the gradient of the returned x; it is
            # accumulated into the input gradient via addmm below.
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                # Start re-gathering x now so it overlaps with the input-gradient matmul.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Launched asynchronously; awaited just before returning.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            # NOTE(review): this returns the un-reduced (batch_dim, out_features) grad_output
            # as the bias gradient; it appears to rely on autograd summing it down to the
            # bias shape — confirm.
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
  109. def fused_dense_func(
  110. x: Tensor,
  111. weight: Tensor,
  112. bias: Optional[Tensor] = None,
  113. return_residual: bool = False,
  114. process_group: Optional[ProcessGroup] = None,
  115. sequence_parallel: bool = True,
  116. ):
  117. dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
  118. x.dtype == torch.float32 and torch.is_autocast_enabled()
  119. )
  120. if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
  121. return FusedDenseFunc.apply(
  122. x, weight, bias, return_residual, process_group, sequence_parallel
  123. )
  124. else:
  125. assert process_group is None
  126. out = F.linear(x, weight, bias)
  127. return out if not return_residual else (out, x)
  128. class FusedDense(nn.Linear):
  129. def __init__(
  130. self,
  131. in_features: int,
  132. out_features: int,
  133. bias: bool = True,
  134. return_residual: bool = False,
  135. device=None,
  136. dtype=None,
  137. ) -> None:
  138. super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
  139. self.return_residual = return_residual
  140. def forward(self, x, process_group=None):
  141. """
  142. If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
  143. we do an all_gather of x before doing the matmul.
  144. """
  145. return fused_dense_func(
  146. x,
  147. self.weight,
  148. self.bias,
  149. return_residual=self.return_residual,
  150. process_group=process_group,
  151. )
  152. class ColumnParallelLinear(nn.Linear):
  153. def __init__(
  154. self,
  155. in_features: int,
  156. out_features: int,
  157. process_group: ProcessGroup,
  158. bias: bool = True,
  159. sequence_parallel=True,
  160. multiple_of=1,
  161. device=None,
  162. dtype=None,
  163. ) -> None:
  164. world_size = torch.distributed.get_world_size(process_group)
  165. if out_features % multiple_of:
  166. raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
  167. multiple = out_features // multiple_of
  168. # We want to split @multiple across world_size, but it could be an uneven split
  169. div = multiple // world_size
  170. mod = multiple % world_size
  171. # The first @mod ranks get @div + 1 copies, the rest get @div copies
  172. local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
  173. super().__init__(
  174. in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
  175. )
  176. self.process_group = process_group
  177. self.sequence_parallel = sequence_parallel
  178. def forward(self, x):
  179. # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
  180. # we do an all_gather of x before doing the matmul.
  181. # If not, then the input is already gathered.
  182. return fused_dense_func(
  183. x,
  184. self.weight,
  185. self.bias,
  186. process_group=self.process_group,
  187. sequence_parallel=self.sequence_parallel,
  188. )
  189. class RowParallelLinear(nn.Linear):
  190. def __init__(
  191. self,
  192. in_features: int,
  193. out_features: int,
  194. process_group: ProcessGroup,
  195. bias: bool = True,
  196. sequence_parallel=True,
  197. multiple_of=1,
  198. device=None,
  199. dtype=None,
  200. ) -> None:
  201. world_size = torch.distributed.get_world_size(process_group)
  202. rank = torch.distributed.get_rank(process_group)
  203. if in_features % multiple_of:
  204. raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
  205. multiple = in_features // multiple_of
  206. # We want to split @multiple across world_size, but it could be an uneven split
  207. div = multiple // world_size
  208. mod = multiple % world_size
  209. # The first @mod ranks get @div + 1 copies, the rest get @div copies
  210. local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
  211. # Only rank 0 will have bias
  212. super().__init__(
  213. local_multiple * multiple_of,
  214. out_features,
  215. bias=bias and rank == 0,
  216. device=device,
  217. dtype=dtype,
  218. )
  219. self.process_group = process_group
  220. self.sequence_parallel = sequence_parallel
  221. def forward(self, x):
  222. """
  223. We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
  224. a reduce_scatter of the result.
  225. """
  226. out = fused_dense_func(x, self.weight, self.bias)
  227. reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
  228. return reduce_fn(out, self.process_group)
class FusedMLPFunc(torch.autograd.Function):
    """Autograd function for ``linear2(activation(linear1(x)))`` with optional parallelism.

    Supported activations are ``"gelu_approx"``, ``"relu"`` and ``"sqrelu"``. With
    ``heuristic == -1`` the first GEMM and the activation run as separate ops
    (``F.linear`` + activation); otherwise ``fused_dense_cuda.linear_act_forward`` is called
    with that heuristic. ``checkpoint_lvl`` trades backward-pass recomputation for memory.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        weight1,
        bias1,
        weight2,
        bias2,
        activation="gelu_approx",
        save_pre_act=True,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic=0,
        process_group=None,
        sequence_parallel=True,
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.
        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out / relu_out in the bwd
        2: recompute pre_act and gelu_out / relu_out in the bwd
        heuristic:
        -1: unfused F.linear followed by the activation (required for "sqrelu")
        0..4: passed to fused_dense_cuda.linear_act_forward as its heuristic
        save_pre_act: if False, checkpoint_lvl is forced to 2.
        Returns output2 reshaped to (*batch_shape, out_features), or (output2, x) when
        return_residual is True.
        """
        assert -1 <= heuristic <= 4
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        # The fused kernel only takes an is_gelu flag (gelu vs. relu), so sqrelu must use the
        # unfused path.
        if activation == "sqrelu":
            assert heuristic == -1
        # Without a saved pre-activation, backward has to recompute everything.
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): like FusedDenseFunc, this checks the *smallest* dimension — confirm intent.
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        if heuristic == -1:
            # Unfused path: separate GEMM, then the activation under the nvFuser context.
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (
                partial(F.gelu, approximate="tanh")
                if activation == "gelu_approx"
                else (sqrelu_fwd if activation == "sqrelu" else F.relu)
            )
            with torch.jit.fuser("fuser2"):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            is_gelu = activation == "gelu_approx"
            # Fused GEMM + activation; the pre-activation is returned only if save_pre_act.
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            # NOTE(review): on the heuristic == -1 path pre_act comes from F.linear and is a
            # full tensor, so the "small bit-mask" reasoning presumably applies only to the
            # fused path.
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            # bias1 is kept so backward can recompute pre_act and output1.
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Return gradients for (x, weight1, bias1, weight2, bias2) and ``None`` for the 7 options."""
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            # Gradient of the returned residual x; accumulated into grad_input via addmm below.
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        # When a gather is needed, total_x is assigned in the checkpoint_lvl branches below.
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                # Asynchronous gather; awaited just before grad_weight1 is computed.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                # Synchronous gather: total_x is needed immediately to recompute pre_act.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            # NOTE(review): un-reduced (batch_dim, out) grad returned as the bias grad; appears
            # to rely on autograd summing it to the bias shape — confirm.
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Launched asynchronously; awaited just before returning.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        # NOTE(review): if the async gather above was started but needs_input_grad[1] is
        # False, handle_x is never awaited — confirm this cannot happen in practice.
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                # checkpoint_lvl 2 gathered synchronously, so there is no handle to wait on.
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            # grad_bias1 was already produced by bias_act_linear_dgrad_bgrad above.
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
  448. def fused_mlp_func(
  449. x: Tensor,
  450. weight1: Tensor,
  451. weight2: Tensor,
  452. bias1: Optional[Tensor] = None,
  453. bias2: Optional[Tensor] = None,
  454. activation: str = "gelu_approx",
  455. save_pre_act: bool = True,
  456. return_residual: bool = False,
  457. checkpoint_lvl: int = 0,
  458. heuristic: int = 0,
  459. process_group: Optional[ProcessGroup] = None,
  460. sequence_parallel: bool = True,
  461. ):
  462. assert activation in ["gelu_approx", "relu", "sqrelu"]
  463. dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
  464. x.dtype == torch.float32 and torch.is_autocast_enabled()
  465. )
  466. # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
  467. dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
  468. if (
  469. x.is_cuda
  470. and weight1.is_cuda
  471. and weight2.is_cuda
  472. and (bias1 is None or bias1.is_cuda)
  473. and (bias2 is None or bias2.is_cuda)
  474. and dtype_eligible
  475. and dim_eligible
  476. ):
  477. return FusedMLPFunc.apply(
  478. x,
  479. weight1,
  480. bias1,
  481. weight2,
  482. bias2,
  483. activation,
  484. save_pre_act,
  485. return_residual,
  486. checkpoint_lvl,
  487. heuristic,
  488. process_group,
  489. sequence_parallel,
  490. )
  491. else:
  492. assert process_group is None
  493. pre_act = F.linear(x, weight1, bias1)
  494. activation_fn = (
  495. partial(F.gelu, approximate="tanh")
  496. if activation == "gelu_approx"
  497. else partial(F.relu, inplace=True)
  498. )
  499. output1 = activation_fn(pre_act)
  500. output2 = F.linear(output1, weight2, bias2)
  501. return output2 if not return_residual else (output2, x)
  502. class FusedMLP(nn.Module):
  503. def __init__(
  504. self,
  505. in_features,
  506. hidden_features=None,
  507. out_features=None,
  508. bias1=True,
  509. bias2=True,
  510. activation="gelu_approx",
  511. return_residual=False,
  512. checkpoint_lvl=0,
  513. heuristic="auto",
  514. device=None,
  515. dtype=None,
  516. ):
  517. """
  518. If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
  519. we do an all_gather of x before doing the matmul, gelu, then matmul.
  520. Finally we do a reduce_scatter of the output.
  521. checkpoint_lvl (increasing lvl means slower but more memory saving):
  522. 0: no recomputation in the bwd
  523. 1: recompute gelu_out in the bwd
  524. 2: recompute pre_act and gelu_out in the bwd
  525. heuristic:
  526. -1: don't fuse gemm + gelu (separate kernel)
  527. 0..4: use this heuristic for the algo section in the fused gemm + gelu
  528. 'auto': heuristic will be picked automatically:
  529. For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
  530. For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
  531. For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
  532. is slower than the unfused version.
  533. return_residual: whether to return the input x along with the output. This is for
  534. performance reason: for post-norm architecture, returning the input allows us
  535. to fuse the backward of nn.Linear with the residual connection.
  536. """
  537. assert checkpoint_lvl in [0, 1, 2]
  538. assert activation in ["gelu_approx", "relu", "sqrelu"]
  539. factory_kwargs = {"device": device, "dtype": dtype}
  540. super().__init__()
  541. out_features = out_features or in_features
  542. hidden_features = hidden_features or in_features * 4
  543. self.activation = activation
  544. self.return_residual = return_residual
  545. self.checkpoint_lvl = checkpoint_lvl
  546. self.heuristic = heuristic if activation != "sqrelu" else -1
  547. self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
  548. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
  549. def forward(self, x, process_group=None):
  550. dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
  551. if self.heuristic == "auto":
  552. if self.activation == "gelu_approx":
  553. if torch.cuda.get_device_capability("cuda") == (9, 0):
  554. heuristic = -1
  555. else:
  556. cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
  557. heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
  558. else:
  559. heuristic = 0
  560. else:
  561. heuristic = self.heuristic
  562. out = fused_mlp_func(
  563. x,
  564. self.fc1.weight,
  565. self.fc2.weight,
  566. self.fc1.bias,
  567. self.fc2.bias,
  568. activation=self.activation,
  569. save_pre_act=self.training,
  570. return_residual=self.return_residual,
  571. checkpoint_lvl=self.checkpoint_lvl,
  572. heuristic=heuristic,
  573. process_group=process_group,
  574. )
  575. if self.return_residual:
  576. out, x = out
  577. if process_group is not None:
  578. out = reduce_scatter(out, process_group)
  579. return out if not self.return_residual else (out, x)
  580. class ParallelFusedMLP(nn.Module):
  581. def __init__(
  582. self,
  583. in_features,
  584. hidden_features=None,
  585. out_features=None,
  586. activation="gelu_approx",
  587. process_group: ProcessGroup = None,
  588. bias1=True,
  589. bias2=True,
  590. sequence_parallel=True,
  591. checkpoint_lvl=0,
  592. heuristic="auto",
  593. device=None,
  594. dtype=None,
  595. ):
  596. """
  597. process_group is required. We're doing Tensor Parallel with sequence parallelism:
  598. we do an all_gather of x before doing the matmul, gelu, then matmul.
  599. Finally we do a reduce_scatter of the output.
  600. checkpoint_lvl (increasing lvl means slower but more memory saving):
  601. 0: no recomputation in the bwd
  602. 1: recompute gelu_out in the bwd
  603. 2: recompute pre_act and gelu_out in the bwd
  604. heuristic:
  605. -1: don't fuse gemm + gelu (separate kernel)
  606. 0..4: use this heuristic for the algo section in the fused gemm + gelu
  607. 'auto': heuristic will be picked automatically:
  608. For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
  609. For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
  610. """
  611. assert checkpoint_lvl in [0, 1, 2]
  612. assert activation in ["gelu_approx", "relu", "sqrelu"]
  613. assert process_group is not None
  614. factory_kwargs = {"device": device, "dtype": dtype}
  615. super().__init__()
  616. out_features = out_features or in_features
  617. hidden_features = hidden_features or in_features * 4
  618. self.activation = activation
  619. self.process_group = process_group
  620. self.sequence_parallel = sequence_parallel
  621. self.checkpoint_lvl = checkpoint_lvl
  622. self.heuristic = heuristic if activation != "sqrelu" else -1
  623. self.fc1 = ColumnParallelLinear(
  624. in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
  625. )
  626. self.fc2 = RowParallelLinear(
  627. hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
  628. )
  629. def forward(self, x):
  630. dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
  631. if self.heuristic == "auto":
  632. if self.activation == "gelu_approx":
  633. cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
  634. heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
  635. else:
  636. heuristic = 0
  637. else:
  638. heuristic = self.heuristic
  639. out = fused_mlp_func(
  640. x,
  641. self.fc1.weight,
  642. self.fc2.weight,
  643. self.fc1.bias,
  644. self.fc2.bias,
  645. activation=self.activation,
  646. save_pre_act=self.training,
  647. checkpoint_lvl=self.checkpoint_lvl,
  648. heuristic=heuristic,
  649. process_group=self.process_group,
  650. sequence_parallel=self.sequence_parallel,
  651. )
  652. reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
  653. return reduce_fn(out, self.process_group)