utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Helper functions for 3D sparse pattern
  2. # These functions are not optimized and are very inefficient.
  3. # Avoid calling them too frequently, or use a caching mechanism.
  4. from functools import lru_cache
  5. import torch
  6. import triton
  7. from scipy import sparse
  8. def dense_to_crow_col(x: torch.Tensor):
  9. """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
  10. NOTE: col_indices padded -1
  11. """
  12. device = x.device
  13. pad = -1
  14. dim = x.dim()
  15. assert x.dim() in (2, 3)
  16. if x.dim() == 2:
  17. x = x[None]
  18. x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
  19. crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
  20. cols = [torch.from_numpy(xi.indices) for xi in x]
  21. max_cols = max(len(xi) for xi in cols)
  22. cols = [
  23. torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
  24. for xi in cols
  25. ]
  26. cols = torch.vstack(cols)
  27. if dim == 2:
  28. crows = crows[0]
  29. cols = cols[0]
  30. return crows.to(device), cols.to(device)
  31. def crow_col_to_dense(crows: torch.Tensor,
  32. cols: torch.Tensor,
  33. dtype: torch.dtype = torch.float16):
  34. dim = crows.dim()
  35. if dim == 1:
  36. crows = crows[None]
  37. cols = cols[None]
  38. device = crows.device
  39. crows, cols = crows.cpu(), cols.cpu() # faster in cpu
  40. shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
  41. x = torch.zeros(shape, dtype=dtype)
  42. for i in range(shape[0]):
  43. for j in range(shape[1]):
  44. x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
  45. if dim == 1:
  46. x = x[0]
  47. return x.to(device)
  48. def dense_to_ccol_row(x: torch.Tensor):
  49. """Similar, but to CSC format"""
  50. x = x.transpose(-2, -1)
  51. return dense_to_crow_col(x)
  52. def ccol_row_to_dense(ccol: torch.Tensor,
  53. rows: torch.Tensor,
  54. dtype: torch.dtype = torch.float16):
  55. return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
  56. def _get_sparse_attn_mask_homo_head(
  57. q_len: int,
  58. max_seqlen: int,
  59. dtype: torch.dtype,
  60. device: torch.device,
  61. block_size: int = 128,
  62. local_blocks: int = 4,
  63. vert_stride: int = 4,
  64. return_dense: bool = False,
  65. ):
  66. """
  67. :return: a tuple of 3:
  68. - tuple of crow_indices, col_indices representation
  69. of CSR format.
  70. - block dense mask
  71. - all token dense mask (be aware that it can be
  72. OOM if it is too big) if `return_dense==True`,
  73. otherwise, None
  74. """
  75. with torch.no_grad():
  76. num_blocks = triton.cdiv(max_seqlen, block_size)
  77. q_pos = torch.arange(num_blocks)[:, None]
  78. k_pos = torch.arange(num_blocks)[None]
  79. mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
  80. block_mask_dense = (((q_pos >= k_pos)
  81. & ((q_pos - k_pos < local_blocks)
  82. | mask_vert_strided)).to(device).to(dtype))
  83. num_blocks_q = triton.cdiv(q_len, block_size)
  84. block_mask_dense_output = (dense_to_crow_col(
  85. block_mask_dense[-num_blocks_q:].contiguous()))
  86. if return_dense:
  87. mask_dense = torch.kron(
  88. block_mask_dense,
  89. block_mask_dense.new_ones((block_size, block_size)),
  90. )
  91. causal_mask = torch.tril(torch.ones(
  92. max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
  93. mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
  94. return (
  95. block_mask_dense_output,
  96. block_mask_dense,
  97. mask_dense,
  98. )
  99. else:
  100. return (
  101. block_mask_dense_output,
  102. block_mask_dense,
  103. None,
  104. )
  105. def binary_mask_to_bias(mask_dense: torch.Tensor):
  106. mask_dense = 1 - mask_dense
  107. mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
  108. return mask_dense
  109. def get_head_sliding_step(n_heads: int,
  110. vert_stride: int,
  111. homo_head: bool = False):
  112. if homo_head:
  113. return 0
  114. return max(1, int(vert_stride / n_heads))
  115. @lru_cache
  116. def get_sparse_attn_mask(
  117. n_heads: int,
  118. q_len: int,
  119. max_seqlen: int,
  120. dtype: torch.dtype,
  121. device: torch.device,
  122. block_size: int = 64,
  123. local_blocks: int = 4,
  124. vert_stride: int = 4,
  125. homo_head: bool = True,
  126. return_dense: bool = False,
  127. dense_mask_type: str = "binary",
  128. ):
  129. """
  130. :param dense_mask_type: "binary" (0 for skip token, 1 for others)
  131. or "bias" (-inf for skip token, 0 or others)
  132. :return: a tuple of 3:
  133. - tuple of crow_indices, col_indices representation
  134. of CSR format.
  135. - block dense mask
  136. - all token dense mask (be aware that it can be OOM if it
  137. is too big) if `return_dense==True`, otherwise, None
  138. """
  139. assert dense_mask_type in ("binary", "bias")
  140. if homo_head:
  141. with torch.no_grad():
  142. (crow, col), block_mask_dense, mask_dense = (
  143. _get_sparse_attn_mask_homo_head(
  144. q_len,
  145. max_seqlen,
  146. dtype,
  147. device,
  148. block_size,
  149. local_blocks,
  150. vert_stride,
  151. return_dense,
  152. ))
  153. crow = crow[None].expand(n_heads, crow.shape[0])
  154. col = col[None].expand(n_heads, col.shape[0])
  155. if return_dense:
  156. mask_dense = mask_dense[None].expand(n_heads,
  157. *mask_dense.shape)
  158. if dense_mask_type == "bias":
  159. mask_dense = binary_mask_to_bias(mask_dense)
  160. return (crow, col), block_mask_dense, mask_dense
  161. with torch.no_grad():
  162. num_blocks = triton.cdiv(max_seqlen, block_size)
  163. q_pos = torch.arange(num_blocks)[None, :, None]
  164. k_pos = torch.arange(num_blocks)[None, None]
  165. head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
  166. mask_vert_strided = [
  167. (torch.arange(num_blocks) + h * head_sliding_step + 1) %
  168. vert_stride == 0 for h in range(n_heads)
  169. ]
  170. mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
  171. block_mask_dense = (((q_pos >= k_pos)
  172. & ((q_pos - k_pos < local_blocks)
  173. | mask_vert_strided)).to(device).to(dtype))
  174. num_blocks_q = triton.cdiv(q_len, block_size)
  175. block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
  176. if return_dense:
  177. mask_dense = torch.kron(
  178. block_mask_dense,
  179. block_mask_dense.new_ones((block_size, block_size)),
  180. )
  181. causal_mask = torch.tril(torch.ones(
  182. max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
  183. mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
  184. if dense_mask_type == "bias":
  185. mask_dense = binary_mask_to_bias(mask_dense)
  186. return (
  187. dense_to_crow_col(block_mask_dense_output),
  188. block_mask_dense,
  189. mask_dense,
  190. )
  191. else:
  192. return (
  193. dense_to_crow_col(block_mask_dense_output),
  194. block_mask_dense,
  195. None,
  196. )