# mlp.py

# The Triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to the naive implementation.
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act
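
# The code below implements a fused two-layer MLP: Linear -> squared ReLU -> Linear.
# FusedDenseSqreluDenseFunc is the custom autograd Function; FusedDenseSqreluDense
# wraps it as an nn.Module. checkpoint_lvl trades backward-pass memory for recomputation.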


class FusedDenseSqreluDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute the activation output (output1) in the bwd
        2: recompute the pre-activation (act_input) and the activation output in the bwd
        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [
                a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
            ]
        is_bf16 = x.dtype == torch.bfloat16
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
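        # bf16 path: CUDA linear followed by a separate sqrelu kernel. fp16 path: Triton
        # kernel that fuses the matmul with the squared-ReLU activation (see top-of-file comment).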
        if is_bf16:
            act_input = fused_dense_cuda.linear_bias_forward(
                x.reshape(batch_dim, n), weight1, bias1
            )
            output1 = sqrelu_fwd(act_input)
        else:
            save_act_input = checkpoint_lvl != 2
            result = triton_linear_act(
                x.reshape(batch_dim, n),
                weight1,
                bias1,
                activation="squared_relu",
                save_act_input=save_act_input,
            )
            if save_act_input:
                output1, act_input = result
            else:
                output1 = result
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
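        # Higher checkpoint_lvl saves fewer intermediates here and recomputes them in backward.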
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, bias1, weight2)
        return output2.reshape(*batch_shape, output2.shape[-1])

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        is_bf16 = x.dtype == torch.bfloat16
        if checkpoint_lvl == 0:
            act_input, output1 = rest
        elif checkpoint_lvl == 1:
            (act_input,) = rest
            output1 = sqrelu_fwd(act_input)
        elif checkpoint_lvl == 2:
            if is_bf16:
                act_input = fused_dense_cuda.linear_bias_forward(
                    x.reshape(batch_dim, n), weight1, bias1
                )
                output1 = sqrelu_fwd(act_input)
            else:
                output1, act_input = triton_linear_act(
                    x.reshape(batch_dim, n),
                    weight1,
                    bias1,
                    activation="squared_relu",
                    save_act_input=True,
                )
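        # Both branches compute the same gradients: weight2/bias2 grads from output1,
        # then backprop through sqrelu to get the pre-activation grad, then through the
        # first linear layer for grad_input, grad_weight1, grad_bias1.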
        if is_bf16:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_output1 = grad_output @ weight2
            grad_act_input = sqrelu_bwd(grad_output1, act_input)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input
            )
        else:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_act_input = triton_dgrad_act(
                grad_output, weight2, activation="squared_relu", act_input=act_input
            )
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input
            )
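        # The trailing None is the gradient for checkpoint_lvl, which is not a tensor input.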
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None


fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply


class FusedDenseSqreluDense(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        checkpoint_lvl=0,
        device=None,
        dtype=None,
    ):
  119. """
  120. checkpoint_lvl (increasing lvl means slower but more memory saving):
  121. 0: no recomputation in the bwd
  122. 1: recompute gelu_out in the bwd
  123. 2: recompute gelu_in and gelu_out in the bwd
  124. """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        assert bias1 == True, "DenseSqreluDense module without bias is currently not supported"
        assert bias2 == True, "DenseSqreluDense module without bias is currently not supported"
        self.checkpoint_lvl = checkpoint_lvl
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        assert x.is_cuda
        return fused_dense_sqrelu_dense_function(
            x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
        )
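

# Minimal usage sketch (not part of the original module). It assumes a CUDA GPU with
# the fused_dense_lib extension and the flash-attn Triton kernels built; the shapes
# and hyperparameters below are illustrative placeholders only.
if __name__ == "__main__":
    torch.manual_seed(0)
    dtype = torch.float16  # fp16 exercises the fused Triton path; bfloat16 would use the CUDA path
    mlp = FusedDenseSqreluDense(1024, hidden_features=4096, checkpoint_lvl=1, device="cuda", dtype=dtype)
    x = torch.randn(2, 128, 1024, device="cuda", dtype=dtype, requires_grad=True)
    out = mlp(x)  # (2, 128, 1024): Linear -> squared ReLU -> Linear
    out.float().sum().backward()  # populates x.grad and the fc1/fc2 weight and bias grads
    print(out.shape, x.grad.shape)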