# test_util.py

import math

import torch
from einops import rearrange, repeat

from flash_attn.bert_padding import pad_input, unpad_input

def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    """Return a bool mask of shape (batch_size, max_seqlen); True marks valid (non-padded) positions."""
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask
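
# Illustrative usage sketch added here for clarity (not part of the original utility; the
# helper name below is hypothetical). The mask generator only needs torch and einops, so it
# runs on CPU as well as GPU.
def _example_padding_mask():
    mask = generate_random_padding_mask(max_seqlen=16, batch_size=4, device="cpu", mode="random")
    # Each row starts with a random number of True (valid) positions, followed by False padding.
    assert mask.shape == (4, 16) and mask.dtype == torch.bool
    return mask
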
def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
    Returns the unpadded (varlen) tensors, their cu_seqlens / max_seqlen metadata, the padded
    tensors with requires_grad set, and functions that re-pad unpadded outputs and gradients.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
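
# Illustrative sketch added for clarity (not part of the original file; the helper name is
# hypothetical). It exercises the default (unpacked) return path of generate_qkv; assuming
# flash_attn is installed, its pad_input/unpad_input helpers are plain PyTorch, so this is
# expected to run on CPU.
def _example_generate_qkv():
    batch_size, seqlen, nheads, d = 2, 8, 3, 16
    q = torch.randn(batch_size, seqlen, nheads, d)
    k = torch.randn(batch_size, seqlen, nheads, d)
    v = torch.randn(batch_size, seqlen, nheads, d)
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device="cpu")
    (
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        q_pad, k_pad, v_pad,
        output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v, key_padding_mask=key_padding_mask)
    # cu_seqlens_k is the cumulative key-length vector; its last entry is the total number of
    # non-padded key tokens, which matches the unpadded batch dimension.
    assert cu_seqlens_k[-1].item() == k_unpad.shape[0]
    # output_pad_fn maps an unpadded (total_q, nheads, d) tensor back to (batch, seqlen, nheads, d).
    assert output_pad_fn(q_unpad).shape == q.shape
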
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    # Returns a bool mask, broadcastable against (batch_size, nheads, seqlen_q, seqlen_k),
    # that is True at positions outside the (left, right) local window; callers fill those
    # score positions with -inf.
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )
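
# Illustrative sketch added for clarity (not part of the original file; the helper name is
# hypothetical). With no padding and window_size=(-1, 0), the local mask reduces to the
# familiar strict upper-triangular causal mask.
def _example_local_mask():
    mask = construct_local_mask(seqlen_q=4, seqlen_k=4, window_size=(-1, 0))
    expected = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
    assert torch.equal(mask, expected)
    return mask
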
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            the output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax weights before dropout
            is applied (dropout only affects the returned output)
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores /= softcap
        scores = scores.tanh()
        scores *= softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out, so we fill them with zero instead of NaN.
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs.
    # Otherwise we'll get NaN in dV.
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
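
# Illustrative sanity check added for clarity (not part of the original file; the helper name
# is hypothetical). For plain causal attention with no padding, dropout, bias, windowing, or
# softcap, attention_ref should agree with PyTorch's scaled_dot_product_attention, which uses
# the (batch, nheads, seqlen, head_dim) layout instead.
def _example_attention_ref():
    torch.manual_seed(0)
    b, s, h, d = 2, 8, 3, 16
    q, k, v = (torch.randn(b, s, h, d) for _ in range(3))
    out_ref, _ = attention_ref(q, k, v, causal=True)
    out_pt = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert torch.allclose(out_ref, out_pt, atol=1e-5)
    return out_ref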