flash_blocksparse_attention.py

import math

import hydra
import torch
import torch.nn as nn
from einops import rearrange

from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
    convert_blockmask,
    flash_blocksparse_attn_func,
)


class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        sparsity_config,
        softmax_temp=None,
        attention_dropout=0.0,
        max_seq_length=2048,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # Initialize the sparse layout and register it as a buffer.
        # Round max_seq_length up to a multiple of 256 so the layout covers whole blocks.
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        # Pre-convert the blockmask once for the convert_mask=False fast path in forward().
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

    def forward(
        self,
        qkv,
        attn_mask=None,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
        convert_mask=True,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                 (B, S, 3, H, D) if key_padding_mask is None
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                              many queries each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            # Padded (B, S, 3, H, D) input: build blockmask and cu_seqlens here.
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Take the subset of the layout that covers this sequence length.
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if key_padding_mask is None:
                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                max_s = seqlen
                cu_seqlens = torch.arange(
                    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
                )
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            else:
                # Unpad the input according to the key padding mask before calling the kernel.
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
                output_unpad = flash_blocksparse_attn_func(
                    x_unpad,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(
                    pad_input(
                        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
                    ),
                    "b s (h d) -> b s h d",
                    h=nheads,
                )
        else:
            # Already-unpadded input: the caller supplies cu_seqlens and max_s.
            assert max_s is not None
            seqlen = max_s
            # Take the subset of the layout that covers this sequence length.
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if convert_mask:
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
            else:
                # Reuse the blockmask converted once in __init__ to skip the per-call conversion.
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    self.blockmask_converted,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                    convert_mask=False,
                )

        return output, None
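

# Calling-convention sketch (comments only, illustrative, not executed). Assumptions
# beyond the code above: a CUDA device and a sparsity_config whose make_layout() is
# compatible with the slicing in forward(); names like `attn`, `batch`, `H`, `D` are
# placeholders.
#
#     attn = FlashBlocksparseAttention(sparsity_config, max_seq_length=1024).cuda()
#     qkv = torch.randn(batch, seqlen, 3, H, D, device="cuda", dtype=torch.float16)
#     out, _ = attn(qkv)                      # out: (batch, seqlen, H, D)
#
# For pre-unpadded ("varlen") input, pass qkv of shape (total_tokens, 3, H, D)
# together with cu_seqlens (int32 prefix sums of the per-sequence lengths, as built
# with torch.arange above) and max_s; max_s is required in that path.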


class FlashBlocksparseMHA(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        sparsity_config,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        max_seq_length=2048,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        assert batch_first
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config,
            attention_dropout=attention_dropout,
            max_seq_length=max_seq_length,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(
        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
    ):
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
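

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: a CUDA device, the flash_attn
    # blocksparse kernels installed, and a hydra-instantiable sparsity config.
    # DenseSparsityConfig is a hypothetical stand-in whose make_layout() keeps
    # every block (i.e. dense attention); a real config would zero out most blocks.
    class DenseSparsityConfig:
        def make_layout(self, seq_len):
            # FlashBlocksparseAttention slices the layout as
            # [: seq_len // 16, : seq_len // 256], so return a block mask of that shape.
            return torch.ones(seq_len // 16, seq_len // 256, dtype=torch.int32)

    mha = (
        FlashBlocksparseMHA(
            embed_dim=512,
            num_heads=8,
            sparsity_config={"_target_": "__main__.DenseSparsityConfig"},
            causal=True,
            max_seq_length=1024,
        )
        .cuda()
        .half()
    )
    x = torch.randn(2, 1024, 512, device="cuda", dtype=torch.float16)
    out, _ = mha(x, None, None)  # the two extra positional arguments are ignored
    print(out.shape)  # expected: torch.Size([2, 1024, 512])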