import torch

from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import (
    scaled_masked_softmax_backward,
    scaled_masked_softmax_forward,
    scaled_masked_softmax_get_batch_per_block,
    scaled_upper_triang_masked_softmax_backward,
    scaled_upper_triang_masked_softmax_forward,
)


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply an upper-triangular (causal) mask, as typically used in GPT models.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None
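

# For reference, a minimal unfused sketch of what ScaledUpperTriangMaskedSoftmax
# computes, assuming a 3D input of shape (attn_batches, sq, sk). This helper is
# editor-added for illustration only and is not used anywhere in this module.
def _reference_upper_triang_masked_softmax(inputs, scale):
    sq, sk = inputs.size(-2), inputs.size(-1)
    # Positions strictly above the diagonal are masked out (causal attention).
    causal_mask = torch.triu(
        torch.ones(sq, sk, dtype=torch.bool, device=inputs.device), diagonal=1
    )
    scores = inputs * scale
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return torch.softmax(scores, dim=-1)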


def scaled_upper_triang_masked_softmax(inputs, _, scale):
    b, np, sq, sk = inputs.size()
    assert sq == sk, "causal mask is only for self attention"

    # Flatten the batch and head dimensions into a 3D tensor (attn_batches, sq, sk).
    inputs = inputs.view(-1, sq, sk)
    args = _cast_if_autocast_enabled(inputs, scale)
    with torch.cuda.amp.autocast(enabled=False):
        probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
    return probs.view(b, np, sq, sk)


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None
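

# For reference, a minimal unfused sketch of what ScaledMaskedSoftmax computes,
# assuming `mask` is a boolean tensor broadcastable to `inputs` in which True
# marks positions to be masked out (the Megatron-style mask convention).
# Editor-added for illustration only; not used by this module.
def _reference_scaled_masked_softmax(inputs, mask, scale):
    scores = inputs * scale
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)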


def scaled_masked_softmax(inputs, mask, scale):
    # inputs is a 4D tensor of shape (b, np, sq, sk).
    args = _cast_if_autocast_enabled(inputs, mask, scale)
    with torch.cuda.amp.autocast(enabled=False):
        return ScaledMaskedSoftmax.apply(*args)


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Arguments:
        input_in_fp16: whether the input is in fp16 format.
        input_in_bf16: whether the input is in bf16 format.
        attn_mask_type: attention mask type (padding or causal).
        scaled_masked_softmax_fusion: whether to use the fused softmax kernel.
        mask_func: mask function to be applied to the input.
        softmax_in_fp32: if True, the softmax is performed in fp32 precision.
        scale: scaling factor applied to the input tensor.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super().__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        if self.input_in_fp16 and self.input_in_bf16:
            raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        if not (self.scale is None or softmax_in_fp32):
            raise RuntimeError("softmax should be in fp32 when scaled")

        if self.scaled_masked_softmax_fusion:
            if self.attn_mask_type == AttnMaskType.causal:
                self.fused_softmax_func = scaled_upper_triang_masked_softmax
            elif self.attn_mask_type == AttnMaskType.padding:
                self.fused_softmax_func = scaled_masked_softmax
            else:
                raise ValueError("Invalid attn_mask_type.")

    def forward(self, input, mask):
        # input: [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np
        if (
            self.scaled_masked_softmax_fusion  # user requested the fused kernel
            and self.input_in_float16  # input must be fp16 or bf16
            and (
                self.attn_mask_type == AttnMaskType.causal
                or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
            )
            and 16 < sk <= 8192  # sk must be in (16, 8192]
            and sq % 4 == 0  # sq must be a multiple of 4
            and sk % 4 == 0  # sk must be a multiple of 4
            and attn_batches % 4 == 0  # b * np must be a multiple of 4
        ):
            if 0 <= sk <= 8192:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False
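
    # Illustrative example (editor-added): with b=8, np=16, sq=sk=1024 and a
    # causal mask, sk lies in (16, 8192] and sq, sk, and b * np are multiples of
    # 4, so the fused kernel is selected as long as attn_batches (128) is
    # divisible by the batch_per_block value reported by the CUDA extension.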

    def forward_fused_softmax(self, input, mask):
        # input: [b, np, sq, sk]
        scale = self.scale if self.scale is not None else 1.0
        return self.fused_softmax_func(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)
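

# A minimal usage sketch (editor-added; not part of the original module). It
# assumes a CUDA device with the fused_softmax_lib extension built, fp16 inputs,
# and the Megatron-style mask convention in which True marks positions to be
# masked out; the shapes and the `attention_mask_func` helper below are
# illustrative assumptions, not part of this module's API.
if __name__ == "__main__":
    def attention_mask_func(attention_scores, attention_mask):
        # Fill masked positions with a large negative value before softmax
        # (only used on the unfused fallback path).
        return attention_scores.masked_fill(attention_mask, -10000.0)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=True,
        mask_func=attention_mask_func,
        softmax_in_fp32=True,
        scale=None,
    )

    b, heads, sq, sk = 2, 8, 128, 128
    scores = torch.randn(b, heads, sq, sk, dtype=torch.float16, device="cuda")
    attention_mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool, device="cuda")
    probs = fused_softmax(scores, attention_mask)
    print(probs.shape)  # expected: torch.Size([2, 8, 128, 128])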