# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.gelu_activation import gelu_bwd


class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
            bias = bias.to(dtype=dtype) if bias is not None else None
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight = weight.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = F.linear(x, weight, bias)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[1]:
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_input = None
        return grad_input, grad_weight, grad_bias, None
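

def _linear_bias_wgrad_reference(x2d: Tensor, grad_output2d: Tensor, has_bias: bool):
    # A hedged reference sketch (not part of the original file): what the
    # fused_dense_cuda.linear_bias_wgrad kernel is assumed to compute, written in plain
    # PyTorch. The CUDA kernel fuses the weight-gradient GEMM and the bias-gradient
    # reduction into one call; this helper is for reading only and is never used here.
    # x2d: (batch_dim, in_features), grad_output2d: (batch_dim, out_features)
    grad_weight = grad_output2d.t() @ x2d  # (out_features, in_features)
    grad_bias = grad_output2d.sum(dim=0) if has_bias else None
    return grad_weight, grad_bias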


def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
            and dtype_eligible):
        return FusedDenseFunc.apply(x, weight, bias, return_residual)
    else:
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)


class FusedDense(nn.Linear):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x):
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
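

def _example_fused_dense():
    # A minimal usage sketch (illustration only, not part of the original file). Shapes,
    # dtype, and device below are assumptions; the fused path is only taken on CUDA with
    # fp16/bf16 inputs (or fp32 under autocast), otherwise fused_dense_func falls back
    # to F.linear.
    layer = FusedDense(1024, 4096, return_residual=True, device='cuda', dtype=torch.float16)
    x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
    # With return_residual=True the input comes back alongside the output, so a caller
    # can fuse the residual add with this layer's backward.
    out, residual = layer(x)
    assert out.shape == (8, 512, 4096)
    return out, residual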


class FusedDenseGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
                checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        if not save_gelu_in:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        if heuristic == -1:
            gelu_in = F.linear(x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
            output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                  bias1, save_gelu_in, heuristic)
            if save_gelu_in:
                gelu_in = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
            gelu_in, output1 = rest
        elif checkpoint_lvl == 1:
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            if ctx.heuristic == -1:
                gelu_in = F.linear(x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
                    x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
                )
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
            if ctx.needs_input_grad[1]:
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
        else:
            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
            # just compute gelu grad
            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
                weight2, grad_output, gelu_in, ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
            if ctx.needs_input_grad[1]:
                grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
            else:
                grad_weight1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_input = None
        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None
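

def _gelu_bwd_reference(grad_output: Tensor, gelu_in: Tensor) -> Tensor:
    # A hedged reference sketch (not part of the original file): what flash_attn's
    # gelu_bwd is assumed to compute, i.e. the derivative of the tanh-approximated GELU
    # evaluated at gelu_in, multiplied by the incoming gradient. The backward above uses
    # the jitted gelu_bwd (or the fused cublasLt epilogue) instead of this helper.
    import math
    k = math.sqrt(2.0 / math.pi)
    x = gelu_in
    tanh_out = torch.tanh(k * (x + 0.044715 * x.pow(3)))
    # d/dx [0.5 * x * (1 + tanh(k * (x + 0.044715 * x^3)))]
    dgelu = (0.5 * (1.0 + tanh_out)
             + 0.5 * x * (1.0 - tanh_out.pow(2)) * k * (1.0 + 3 * 0.044715 * x.pow(2)))
    return grad_output * dgelu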


def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    save_gelu_in: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0
):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
            and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2,
            save_gelu_in, return_residual, checkpoint_lvl, heuristic
        )
    else:
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
            See the usage sketch at the end of this file.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        return fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_gelu_in=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
        )
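

def _example_fused_dense_gelu_dense():
    # A minimal usage sketch (illustration only, not part of the original file). It
    # assumes a CUDA build of PyTorch and follows the docstring guidance for picking the
    # heuristic: 0 on CUDA >= 11.8, 1 for fp16 on older CUDA (-1 would be used for bf16).
    cuda_version = tuple(int(v) for v in torch.version.cuda.split('.')[:2])
    heuristic = 0 if cuda_version >= (11, 8) else 1
    mlp = FusedDenseGeluDense(1024, 4096, checkpoint_lvl=1, heuristic=heuristic,
                              device='cuda', dtype=torch.float16)
    x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
    out = mlp(x)          # fc1 -> tanh-approx GELU -> fc2, fused where possible
    out.sum().backward()  # checkpoint_lvl=1 recomputes gelu_out in the backward pass
    return out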