# test_util.py

import math

import torch
from einops import rearrange, repeat

from padding import pad_input, unpad_input

def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask
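

# Usage sketch (not part of the original file): build a "random" padding mask and
# check its basic invariants. The sizes, the device fallback, and the helper name
# `_example_padding_mask` are assumptions made for illustration.
def _example_padding_mask():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    padding_mask = generate_random_padding_mask(max_seqlen=8, batch_size=4, device=device, mode="random")
    # Shape (batch_size, max_seqlen); True marks real (kept) tokens, False marks padding.
    assert padding_mask.shape == (4, 8)
    # With zero_lengths=False every sequence keeps at least one token.
    assert (padding_mask.sum(dim=-1) >= 1).all()
    return padding_mask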


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False,
    query_unused_mask=None, key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
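

# Usage sketch (not part of the original file): unpad q/k/v for a varlen kernel and
# re-pad an output with `output_pad_fn`. The shapes, the device fallback, and the
# helper name `_example_generate_qkv` are assumptions made for illustration.
def _example_generate_qkv():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size, seqlen, nheads, d = 2, 16, 4, 32
    q, k, v = [torch.randn(batch_size, seqlen, nheads, d, device=device) for _ in range(3)]
    query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device)
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device)
    (
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q, cu_seqlens_k,
        seqused_q, seqused_k,
        max_seqlen_q, max_seqlen_k,
        q_ref, k_ref, v_ref,
        output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)
    # q_unpad is (total_q, nheads, d); output_pad_fn scatters a varlen output back to
    # the padded (batch_size, seqlen, nheads, d) layout for comparison with a reference.
    out_padded = output_pad_fn(q_unpad)
    assert out_padded.shape == q.shape
    return out_padded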


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
        )
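

# Usage sketch (not part of the original file): with window_size=(-1, 0) the local mask
# reduces to a plain causal mask. The tiny sizes and the helper name
# `_example_local_mask` are assumptions made for illustration.
def _example_local_mask():
    mask = construct_local_mask(seqlen_q=4, seqlen_k=4, window_size=(-1, 0))
    # True marks positions that are masked out; for seqlen_q == seqlen_k this is the
    # strictly upper-triangular part of the (seqlen_q, seqlen_k) score matrix.
    assert torch.equal(mask, torch.ones(4, 4).triu(1).bool())
    return mask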


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b (h g)", g=q.shape[2] // k.shape[2])
        q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype)
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
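

# Usage sketch (not part of the original file): attention_ref serves as the ground-truth
# baseline in tests. Below it is checked against PyTorch's scaled_dot_product_attention
# (available in torch >= 2.0) for plain causal attention without padding or dropout.
# The shapes, tolerance, and the helper name `_example_attention_ref` are assumptions.
def _example_attention_ref():
    torch.manual_seed(0)
    batch_size, seqlen, nheads, d = 2, 16, 4, 32
    q, k, v = [torch.randn(batch_size, seqlen, nheads, d) for _ in range(3)]
    out_ref, _ = attention_ref(q, k, v, causal=True)
    # scaled_dot_product_attention expects (batch, nheads, seqlen, head_dim).
    out_pt = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert torch.allclose(out_ref, out_pt, atol=1e-5)
    return out_ref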