utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Helper functions for 3D sparse pattern
  2. # These functions are not optimized and are very inefficient.
  3. # Avoid calling them too frequently, or use a caching mechanism.
  4. from functools import lru_cache
  5. import numpy as np
  6. import torch
  7. import triton
  8. class csr_matrix:
  9. """Simple implementation of CSR matrix conversion without scipy.
  10. This replaced scipy.sparse.csr_matrix() previously used."""
  11. def __init__(self, input_array):
  12. if not isinstance(input_array, np.ndarray):
  13. raise ValueError("Input must be a NumPy array")
  14. self.shape = input_array.shape
  15. rows, cols = self.shape
  16. data = []
  17. indices = []
  18. indptr = [0]
  19. for i in range(rows):
  20. for j in range(cols):
  21. if input_array[i, j]:
  22. data.append(input_array[i, j])
  23. indices.append(j)
  24. indptr.append(len(indices))
  25. self.data = np.array(data)
  26. self.indices = np.array(indices)
  27. self.indptr = np.array(indptr)
  28. def dense_to_crow_col(x: torch.Tensor):
  29. """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
  30. NOTE: col_indices padded -1
  31. """
  32. device = x.device
  33. pad = -1
  34. dim = x.dim()
  35. assert x.dim() in (2, 3)
  36. if x.dim() == 2:
  37. x = x[None]
  38. x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
  39. crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
  40. cols = [torch.from_numpy(xi.indices) for xi in x]
  41. max_cols = max(len(xi) for xi in cols)
  42. cols = [
  43. torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
  44. for xi in cols
  45. ]
  46. cols = torch.vstack(cols)
  47. if dim == 2:
  48. crows = crows[0]
  49. cols = cols[0]
  50. return crows.to(device), cols.to(device)
  51. def crow_col_to_dense(crows: torch.Tensor,
  52. cols: torch.Tensor,
  53. dtype: torch.dtype = torch.float16):
  54. dim = crows.dim()
  55. if dim == 1:
  56. crows = crows[None]
  57. cols = cols[None]
  58. device = crows.device
  59. crows, cols = crows.cpu(), cols.cpu() # faster in cpu
  60. shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
  61. x = torch.zeros(shape, dtype=dtype)
  62. for i in range(shape[0]):
  63. for j in range(shape[1]):
  64. x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
  65. if dim == 1:
  66. x = x[0]
  67. return x.to(device)
  68. def dense_to_ccol_row(x: torch.Tensor):
  69. """Similar, but to CSC format"""
  70. x = x.transpose(-2, -1)
  71. return dense_to_crow_col(x)
  72. def ccol_row_to_dense(ccol: torch.Tensor,
  73. rows: torch.Tensor,
  74. dtype: torch.dtype = torch.float16):
  75. return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
  76. def _get_sparse_attn_mask_homo_head(
  77. q_len: int,
  78. max_seqlen: int,
  79. dtype: torch.dtype,
  80. device: torch.device,
  81. block_size: int = 128,
  82. local_blocks: int = 4,
  83. vert_stride: int = 4,
  84. return_dense: bool = False,
  85. ):
  86. """
  87. :return: a tuple of 3:
  88. - tuple of crow_indices, col_indices representation
  89. of CSR format.
  90. - block dense mask
  91. - all token dense mask (be aware that it can be
  92. OOM if it is too big) if `return_dense==True`,
  93. otherwise, None
  94. """
  95. with torch.no_grad():
  96. num_blocks = triton.cdiv(max_seqlen, block_size)
  97. q_pos = torch.arange(num_blocks)[:, None]
  98. k_pos = torch.arange(num_blocks)[None]
  99. mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
  100. block_mask_dense = (((q_pos >= k_pos)
  101. & ((q_pos - k_pos < local_blocks)
  102. | mask_vert_strided)).to(device).to(dtype))
  103. num_blocks_q = triton.cdiv(q_len, block_size)
  104. block_mask_dense_output = (dense_to_crow_col(
  105. block_mask_dense[-num_blocks_q:].contiguous()))
  106. if return_dense:
  107. mask_dense = torch.kron(
  108. block_mask_dense,
  109. block_mask_dense.new_ones((block_size, block_size)),
  110. )
  111. causal_mask = torch.tril(torch.ones(
  112. max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
  113. mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
  114. return (
  115. block_mask_dense_output,
  116. block_mask_dense,
  117. mask_dense,
  118. )
  119. else:
  120. return (
  121. block_mask_dense_output,
  122. block_mask_dense,
  123. None,
  124. )
  125. def binary_mask_to_bias(mask_dense: torch.Tensor):
  126. mask_dense = 1 - mask_dense
  127. mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
  128. return mask_dense
  129. def get_head_sliding_step(n_heads: int,
  130. vert_stride: int,
  131. homo_head: bool = False):
  132. if homo_head:
  133. return 0
  134. return max(1, int(vert_stride / n_heads))
  135. @lru_cache
  136. def get_sparse_attn_mask(
  137. n_heads: int,
  138. q_len: int,
  139. max_seqlen: int,
  140. dtype: torch.dtype,
  141. device: torch.device,
  142. block_size: int = 64,
  143. local_blocks: int = 4,
  144. vert_stride: int = 4,
  145. homo_head: bool = True,
  146. return_dense: bool = False,
  147. dense_mask_type: str = "binary",
  148. ):
  149. """
  150. :param dense_mask_type: "binary" (0 for skip token, 1 for others)
  151. or "bias" (-inf for skip token, 0 or others)
  152. :return: a tuple of 3:
  153. - tuple of crow_indices, col_indices representation
  154. of CSR format.
  155. - block dense mask
  156. - all token dense mask (be aware that it can be OOM if it
  157. is too big) if `return_dense==True`, otherwise, None
  158. """
  159. assert dense_mask_type in ("binary", "bias")
  160. if homo_head:
  161. with torch.no_grad():
  162. (crow, col), block_mask_dense, mask_dense = (
  163. _get_sparse_attn_mask_homo_head(
  164. q_len,
  165. max_seqlen,
  166. dtype,
  167. device,
  168. block_size,
  169. local_blocks,
  170. vert_stride,
  171. return_dense,
  172. ))
  173. crow = crow[None].expand(n_heads, crow.shape[0])
  174. col = col[None].expand(n_heads, col.shape[0])
  175. if return_dense:
  176. mask_dense = mask_dense[None].expand(n_heads,
  177. *mask_dense.shape)
  178. if dense_mask_type == "bias":
  179. mask_dense = binary_mask_to_bias(mask_dense)
  180. return (crow, col), block_mask_dense, mask_dense
  181. with torch.no_grad():
  182. num_blocks = triton.cdiv(max_seqlen, block_size)
  183. q_pos = torch.arange(num_blocks)[None, :, None]
  184. k_pos = torch.arange(num_blocks)[None, None]
  185. head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
  186. mask_vert_strided = [
  187. (torch.arange(num_blocks) + h * head_sliding_step + 1) %
  188. vert_stride == 0 for h in range(n_heads)
  189. ]
  190. mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
  191. block_mask_dense = (((q_pos >= k_pos)
  192. & ((q_pos - k_pos < local_blocks)
  193. | mask_vert_strided)).to(device).to(dtype))
  194. num_blocks_q = triton.cdiv(q_len, block_size)
  195. block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
  196. if return_dense:
  197. mask_dense = torch.kron(
  198. block_mask_dense,
  199. block_mask_dense.new_ones((block_size, block_size)),
  200. )
  201. causal_mask = torch.tril(torch.ones(
  202. max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
  203. mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
  204. if dense_mask_type == "bias":
  205. mask_dense = binary_mask_to_bias(mask_dense)
  206. return (
  207. dense_to_crow_col(block_mask_dense_output),
  208. block_mask_dense,
  209. mask_dense,
  210. )
  211. else:
  212. return (
  213. dense_to_crow_col(block_mask_dense_output),
  214. block_mask_dense,
  215. None,
  216. )