# Helper functions for 3D sparse pattern
# These functions are not optimized and are very inefficient.
# Avoid calling them too frequently, or use a caching mechanism.
from functools import lru_cache

import numpy as np
import torch
import triton


class csr_matrix:
    """Simple implementation of CSR matrix conversion without scipy.
    This replaces the scipy.sparse.csr_matrix() previously used."""

    def __init__(self, input_array):
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")

        self.shape = input_array.shape
        rows, cols = self.shape
        data = []
        indices = []
        indptr = [0]

        # Scan row by row, recording the column index of every nonzero
        # entry and the running nonzero count at the end of each row.
        for i in range(rows):
            for j in range(cols):
                if input_array[i, j]:
                    data.append(input_array[i, j])
                    indices.append(j)
            indptr.append(len(indices))

        self.data = np.array(data)
        self.indices = np.array(indices)
        self.indptr = np.array(indptr)
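

# Illustrative example (not part of the original module): converting a small
# dense array. `indptr[i]:indptr[i + 1]` then gives the slice of `indices`
# (and `data`) that belongs to row i.
#
#   >>> m = csr_matrix(np.array([[1, 0, 2], [0, 0, 3]]))
#   >>> m.data.tolist(), m.indices.tolist(), m.indptr.tolist()
#   ([1, 2, 3], [0, 2, 2], [0, 2, 3])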


def dense_to_crow_col(x: torch.Tensor):
    """Turn a 2D/3D torch tensor (x) into CSR rows/cols indexing.
    NOTE: col_indices are padded with -1.
    """
    device = x.device
    pad = -1
    dim = x.dim()
    assert x.dim() in (2, 3)
    if x.dim() == 2:
        x = x[None]
    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
    cols = [torch.from_numpy(xi.indices) for xi in x]
    max_cols = max(len(xi) for xi in cols)
    # Pad the per-batch col_indices with -1 so they can be stacked.
    cols = [
        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
        for xi in cols
    ]
    cols = torch.vstack(cols)
    if dim == 2:
        crows = crows[0]
        cols = cols[0]
    return crows.to(device), cols.to(device)
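

# Illustrative example (not part of the original module): a 2D mask maps to a
# single (crow_indices, col_indices) pair; for 3D input each batch element gets
# one row, padded with -1 up to the largest nonzero count.
#
#   >>> crows, cols = dense_to_crow_col(torch.tensor([[1, 0, 1],
#   ...                                               [0, 1, 1]]))
#   >>> crows.tolist(), cols.tolist()
#   ([0, 2, 4], [0, 2, 1, 2])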


def crow_col_to_dense(crows: torch.Tensor,
                      cols: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    dim = crows.dim()
    if dim == 1:
        crows = crows[None]
        cols = cols[None]
    device = crows.device
    crows, cols = crows.cpu(), cols.cpu()  # faster on CPU
    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
    x = torch.zeros(shape, dtype=dtype)
    for i in range(shape[0]):
        for j in range(shape[1]):
            x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
    if dim == 1:
        x = x[0]
    return x.to(device)
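

# Illustrative example (not part of the original module): round trip of the
# pattern above back to dense. Note the width is inferred as cols.max() + 1, so
# trailing all-zero key columns are not reconstructed.
#
#   >>> crow_col_to_dense(torch.tensor([0, 2, 4]),
#   ...                   torch.tensor([0, 2, 1, 2])).tolist()
#   [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]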


def dense_to_ccol_row(x: torch.Tensor):
    """Similar, but to CSC format"""
    x = x.transpose(-2, -1)
    return dense_to_crow_col(x)


def ccol_row_to_dense(ccol: torch.Tensor,
                      rows: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
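

# Illustrative example (not part of the original module): the CSC helpers are
# thin wrappers around the CSR ones applied to the transposed tensor. Note that
# ccol_row_to_dense assumes a batched (3D) layout because of the permute.
#
#   >>> x = torch.tensor([[1, 0, 1], [0, 1, 1]])[None]      # (1, 2, 3)
#   >>> ccol, rows = dense_to_ccol_row(x)
#   >>> torch.equal(ccol_row_to_dense(ccol, rows).bool(), x.bool())
#   True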


def _get_sparse_attn_mask_homo_head(
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 128,
    local_blocks: int = 4,
    vert_stride: int = 4,
    return_dense: bool = False,
):
    """
    :return: a tuple of 3:
        - tuple of crow_indices and col_indices (CSR format) of the
          block-level mask.
        - block dense mask
        - all-token dense mask (be aware that it can cause OOM if it
          is too big) if `return_dense == True`, otherwise None.
    """
    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        q_pos = torch.arange(num_blocks)[:, None]
        k_pos = torch.arange(num_blocks)[None]
        # Every vert_stride-th key block is attended to by all query blocks.
        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
        # Causal (q >= k), plus a local band of `local_blocks` blocks,
        # plus the vertical strided blocks.
        block_mask_dense = (((q_pos >= k_pos)
                             & ((q_pos - k_pos < local_blocks)
                                | mask_vert_strided)).to(device).to(dtype))
        num_blocks_q = triton.cdiv(q_len, block_size)
        block_mask_dense_output = (dense_to_crow_col(
            block_mask_dense[-num_blocks_q:].contiguous()))
    if return_dense:
        # Expand the block-level mask to token level and re-apply the
        # causal constraint at token granularity.
        mask_dense = torch.kron(
            block_mask_dense,
            block_mask_dense.new_ones((block_size, block_size)),
        )
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
        return (
            block_mask_dense_output,
            block_mask_dense,
            mask_dense,
        )
    else:
        return (
            block_mask_dense_output,
            block_mask_dense,
            None,
        )
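

# Illustrative pattern (not part of the original module; values chosen only for
# readability): with 6 blocks, local_blocks=2 and vert_stride=3, the block-level
# mask is a causal local band plus vertical stripes at every 3rd key block.
#
#   >>> (crow, col), bm, _ = _get_sparse_attn_mask_homo_head(
#   ...     12, 12, torch.float32, torch.device("cpu"),
#   ...     block_size=2, local_blocks=2, vert_stride=3)
#   >>> bm.int().tolist()
#   [[1, 0, 0, 0, 0, 0],
#    [1, 1, 0, 0, 0, 0],
#    [0, 1, 1, 0, 0, 0],
#    [0, 0, 1, 1, 0, 0],
#    [0, 0, 1, 1, 1, 0],
#    [0, 0, 1, 0, 1, 1]]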


def binary_mask_to_bias(mask_dense: torch.Tensor):
    """Convert a 0/1 mask into an additive bias: kept positions become 0,
    masked positions become -inf."""
    mask_dense = 1 - mask_dense
    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
    return mask_dense
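

# Illustrative example (not part of the original module):
#
#   >>> binary_mask_to_bias(torch.tensor([[1.0, 0.0]]))
#   tensor([[0., -inf]])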


def get_head_sliding_step(n_heads: int,
                          vert_stride: int,
                          homo_head: bool = False):
    if homo_head:
        return 0
    return max(1, int(vert_stride / n_heads))
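

# Illustrative values (not part of the original module): the sliding step
# offsets each head's vertical stripes so different heads attend to different
# key blocks; it is clamped to at least 1, and 0 means all heads share a mask.
#
#   >>> get_head_sliding_step(n_heads=8, vert_stride=16)
#   2
#   >>> get_head_sliding_step(n_heads=8, vert_stride=4)
#   1
#   >>> get_head_sliding_step(n_heads=8, vert_stride=16, homo_head=True)
#   0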


@lru_cache
def get_sparse_attn_mask(
    n_heads: int,
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 64,
    local_blocks: int = 4,
    vert_stride: int = 4,
    homo_head: bool = True,
    return_dense: bool = False,
    dense_mask_type: str = "binary",
):
    """
    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
        or "bias" (-inf for skip token, 0 for others)
    :return: a tuple of 3:
        - tuple of crow_indices and col_indices (CSR format) of the
          block-level mask.
        - block dense mask
        - all-token dense mask (be aware that it can cause OOM if it
          is too big) if `return_dense == True`, otherwise None.
    """
    assert dense_mask_type in ("binary", "bias")

    if homo_head:
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = (
                _get_sparse_attn_mask_homo_head(
                    q_len,
                    max_seqlen,
                    dtype,
                    device,
                    block_size,
                    local_blocks,
                    vert_stride,
                    return_dense,
                ))
            # All heads share the same pattern, so just expand along the
            # head dimension.
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads,
                                                     *mask_dense.shape)
                if dense_mask_type == "bias":
                    mask_dense = binary_mask_to_bias(mask_dense)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        q_pos = torch.arange(num_blocks)[None, :, None]
        k_pos = torch.arange(num_blocks)[None, None]
        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
        # Each head gets its own set of vertical strided key blocks,
        # shifted by head_sliding_step.
        mask_vert_strided = [
            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
            vert_stride == 0 for h in range(n_heads)
        ]
        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
        block_mask_dense = (((q_pos >= k_pos)
                             & ((q_pos - k_pos < local_blocks)
                                | mask_vert_strided)).to(device).to(dtype))
        num_blocks_q = triton.cdiv(q_len, block_size)
        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
    if return_dense:
        mask_dense = torch.kron(
            block_mask_dense,
            block_mask_dense.new_ones((block_size, block_size)),
        )
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
        if dense_mask_type == "bias":
            mask_dense = binary_mask_to_bias(mask_dense)
        return (
            dense_to_crow_col(block_mask_dense_output),
            block_mask_dense,
            mask_dense,
        )
    else:
        return (
            dense_to_crow_col(block_mask_dense_output),
            block_mask_dense,
            None,
        )
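

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). Parameter values
    # below are illustrative only, not tuned for any particular model; float32
    # and CPU are used so the sketch runs anywhere.
    (crow, col), block_mask, dense = get_sparse_attn_mask(
        n_heads=8,
        q_len=256,
        max_seqlen=256,
        dtype=torch.float32,
        device=torch.device("cpu"),
        block_size=64,
        local_blocks=2,
        vert_stride=4,
        homo_head=False,
        return_dense=True,
        dense_mask_type="bias",
    )
    print("crow_indices:", tuple(crow.shape))        # (n_heads, num_q_blocks + 1)
    print("col_indices: ", tuple(col.shape))         # (n_heads, padded nnz)
    print("block mask:  ", tuple(block_mask.shape))  # (n_heads, blocks, blocks)
    print("dense mask:  ", tuple(dense.shape))       # (n_heads, q_len, max_seqlen)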