# fwd_ref.py

import torch
import math

from .utils import DEBUG


def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
    if DEBUG:
        print()
        print("attention_forward_core_ref_impl")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale:", sm_scale)
        print("causal:", causal)
        print("use_exp2:", use_exp2)

    # Compute attention scores in float32 for accuracy
    attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
    if DEBUG:
        print("attention_scores:", attention_scores, attention_scores.shape)

    # Scale scores
    attention_scaled_scores = sm_scale * attention_scores
    if DEBUG:
        print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape)

    # Apply causal mask if necessary
    if causal:
        L_q, L_k = q.shape[1], k.shape[1]
        row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
        col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
        col_offset = L_q - L_k
        causal_mask = row_idx >= (col_offset + col_idx)
        if DEBUG:
            print("causal_mask:", causal_mask)
        # Set scores to -inf where the causal mask is False
        attention_scaled_scores = attention_scaled_scores.masked_fill(
            torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
        )
        if DEBUG:
            print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape)

    # Compute max for numerical stability
    max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0]
    if DEBUG:
        print("max_scores:", max_scores, max_scores.shape)
    if causal:
        # Replace -inf in max_scores with zeros to avoid NaN in the subtraction below
        max_scores = torch.where(
            torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores
        )
        if DEBUG:
            print("max_scores if causal:", max_scores, max_scores.shape)

    # Shift scores
    attention_shifted_scaled_scores = attention_scaled_scores - max_scores
    if DEBUG:
        print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape)

    # Exponentiate
    if use_exp2:
        RCP_LN = 1 / math.log(2)
        exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores)
    else:
        exp_scores = torch.exp(attention_shifted_scaled_scores)
    if DEBUG:
        print("exp_scores:", exp_scores, exp_scores.shape)

    # Sum of exponentials
    sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True)
    if DEBUG:
        print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape)
    if causal:
        # If the sum of exp scores is 0.0, every score in that row was -inf and
        # softmax / softmax_lse are undefined. Setting the sum to 1 handles the
        # all -inf case cleanly.
        sum_exp_scores = torch.where(
            sum_exp_scores == 0,
            torch.ones_like(sum_exp_scores),
            sum_exp_scores
        )
        if DEBUG:
            print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape)

    # Compute softmax probabilities
    softmax = exp_scores / sum_exp_scores
    if DEBUG:
        print("softmax:", softmax, softmax.shape)

    # Compute log-sum-exp
    if use_exp2:
        LN2 = math.log(2)
        RCP_LN = 1 / math.log(2)
        max_scores_base2 = max_scores * RCP_LN
        softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores)
        softmax_lse = softmax_lse_base2 * LN2
        softmax_lse.squeeze_(-1)
    else:
        softmax_lse = max_scores + torch.log(sum_exp_scores)
        softmax_lse = softmax_lse.squeeze(-1)
    if DEBUG:
        print("softmax_lse:", softmax_lse, softmax_lse.shape)

    # Compute output
    o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
    if DEBUG:
        print("o:", o, o.shape)

    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
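

# Illustrative sanity check (not part of the original reference implementation):
# a minimal sketch comparing attention_forward_core_ref_impl against
# torch.nn.functional.scaled_dot_product_attention on random inputs. The shapes,
# seed, and tolerances below are assumptions chosen for the demo, and the helper
# name _example_core_ref_sanity_check is hypothetical.
def _example_core_ref_sanity_check():
    torch.manual_seed(0)
    bh, seqlen, head_dim = 2, 8, 16  # merged batch*heads, sequence length, head size
    q = torch.randn(bh, seqlen, head_dim)
    k = torch.randn(bh, seqlen, head_dim)
    v = torch.randn(bh, seqlen, head_dim)
    sm_scale = 1.0 / math.sqrt(head_dim)  # matches SDPA's default scaling

    o, softmax_lse, *_ = attention_forward_core_ref_impl(
        q, k, v, sm_scale, causal=True, use_exp2=True
    )
    o_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    # The reference output is float16, so compare in float32 with loose tolerances
    torch.testing.assert_close(o.to(torch.float32), o_sdpa, atol=1e-2, rtol=1e-2)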


def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
    """Compute the reference output and softmax_lse with a pure PyTorch implementation."""
    # Ensure the layout is 'bhsd'
    if layout == "bshd":
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
    elif layout != "bhsd":
        raise ValueError(f"Unknown layout {layout}")

    # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
    batch_size, num_heads, seq_len_q, head_dim = q.shape
    seq_len_k = k.shape[2]

    # Merge batch and head dimensions
    q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
    k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
    v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)

    # Call the core attention function
    o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl(
        q, k, v, sm_scale, causal, use_exp2
    )

    # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
    o = o.reshape(batch_size, num_heads, seq_len_q, head_dim)
    softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q)
    exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)

    # Restore the original layout if necessary
    if layout == "bshd":
        o = o.transpose(1, 2)

    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
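

# Illustrative usage sketch (not part of the original file): calling the vanilla
# reference with tensors in the 'bshd' layout and checking that the output comes
# back in the same layout while softmax_lse stays in [batch, heads, seq] order.
# The shapes and the helper name _example_vanilla_usage are assumptions for the demo.
def _example_vanilla_usage():
    batch, heads, seqlen, head_dim = 2, 4, 8, 16
    q = torch.randn(batch, seqlen, heads, head_dim)  # 'bshd': [batch, seq, heads, dim]
    k = torch.randn(batch, seqlen, heads, head_dim)
    v = torch.randn(batch, seqlen, heads, head_dim)
    sm_scale = 1.0 / math.sqrt(head_dim)

    o, softmax_lse, *_ = attention_vanilla_forward_pytorch_ref_impl(
        q, k, v, sm_scale, causal=False, layout="bshd", use_exp2=False
    )

    assert o.shape == (batch, seqlen, heads, head_dim)  # output follows the input layout
    assert softmax_lse.shape == (batch, heads, seqlen)  # LSE is kept in 'bhsd' order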


def attention_varlen_forward_pytorch_ref_impl(
    q,
    k,
    v,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2
):
    # Ensure the layout is 'thd'
    if layout != 'thd':
        raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")

    batch_size = cu_seqlens_q.shape[0] - 1
    num_heads = q.shape[1]
    head_dim = q.shape[2]

    # Pre-allocate outputs
    total_L_q = q.shape[0]
    total_L_k = k.shape[0]
    o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device)
    softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device)

    for i in range(batch_size):
        # Get the start and end indices for the current sequence
        start_q = cu_seqlens_q[i].item()
        end_q = cu_seqlens_q[i + 1].item()
        start_k = cu_seqlens_k[i].item()
        end_k = cu_seqlens_k[i + 1].item()

        # Extract q_i, k_i, v_i
        q_i = q[start_q:end_q, :, :]  # [L_q_i, num_heads, head_dim]
        k_i = k[start_k:end_k, :, :]  # [L_k_i, num_heads, head_dim]
        v_i = v[start_k:end_k, :, :]  # [L_k_i, num_heads, head_dim]

        # Permute to [num_heads, L_q_i, head_dim]
        q_i = q_i.permute(1, 0, 2)
        k_i = k_i.permute(1, 0, 2)
        v_i = v_i.permute(1, 0, 2)

        # Call the core attention function for this sequence
        (
            o_i,
            softmax_lse_i,
            exp_scores_i,
            softmax_i,
            attention_shifted_scaled_scores_i,
            attention_scaled_scores_i,
            attention_scores_i,
        ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2)

        # Convert back to 'thd' layout and float16
        o_i = o_i.permute(1, 0, 2).to(torch.float16)  # [L_q_i, num_heads, head_dim]

        # Place outputs in the pre-allocated tensors
        o[start_q:end_q, :, :] = o_i
        softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1)  # Transpose to [L_q_i, num_heads]

        # The per-sequence score tensors are ragged ([L_q_i, L_k_i] differs per
        # sequence), so they are only permuted here and not gathered into
        # batched outputs.
        # exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i]
        exp_scores_i = exp_scores_i.permute(1, 0, 2)
        softmax_i = softmax_i.permute(1, 0, 2)
        attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2)
        attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2)
        attention_scores_i = attention_scores_i.permute(1, 0, 2)

    # The ragged intermediate tensors are not returned for the varlen path
    return (
        o,
        softmax_lse,
        None,
        None,
        None,
        None,
        None,
    )
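

# Illustrative usage sketch (not part of the original file): packing two sequences
# of different lengths into the 'thd' layout and building the cu_seqlens offsets the
# varlen reference expects. The sequence lengths and the helper name
# _example_varlen_usage are assumptions for the demo.
def _example_varlen_usage():
    seqlens = [5, 3]  # two sequences packed back to back along the token dimension
    num_heads, head_dim = 4, 16
    total = sum(seqlens)
    # Cumulative offsets: sequence i occupies rows [cu_seqlens[i], cu_seqlens[i + 1])
    cu_seqlens = torch.tensor([0, seqlens[0], seqlens[0] + seqlens[1]], dtype=torch.int32)

    q = torch.randn(total, num_heads, head_dim, dtype=torch.float16)
    k = torch.randn(total, num_heads, head_dim, dtype=torch.float16)
    v = torch.randn(total, num_heads, head_dim, dtype=torch.float16)
    sm_scale = 1.0 / math.sqrt(head_dim)

    o, softmax_lse, *_ = attention_varlen_forward_pytorch_ref_impl(
        q, k, v, sm_scale,
        causal=True, layout="thd",
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
        use_exp2=False,
    )

    assert o.shape == (total, num_heads, head_dim)
    assert softmax_lse.shape == (total, num_heads)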


def attention_forward_pytorch_ref_impl(
    q,
    k,
    v,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2
):
    if DEBUG:
        print()
        print("attention_forward_pytorch_ref_impl")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale:", sm_scale)
        print("causal:", causal)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("use_exp2:", use_exp2)

    # compute reference
    if layout == "thd":
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_varlen_forward_pytorch_ref_impl(
            q.clone(),
            k.clone(),
            v.clone(),
            sm_scale,
            causal,
            layout,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            use_exp2,
        )
    else:
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_vanilla_forward_pytorch_ref_impl(
            q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
        )

    if DEBUG:
        print()
        print("attention_forward_pytorch_ref_impl outputs")
        print("o_ref:", o_ref, o_ref.shape)
        print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
        print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None)

    return (
        o_ref,
        softmax_lse_ref,
        exp_scores_ref,
        softmax_ref,
        attention_shifted_scaled_scores_ref,
        attention_scaled_scores_ref,
        attention_scores_ref,
    )


def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
    q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1)  # (N_CTX_Q, 1)
    k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0)  # (1, N_CTX_K)
    relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx)  # (N_CTX_Q, N_CTX_K)
    return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
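

# Illustrative usage sketch (not part of the original file): building an ALiBi bias
# for an assumed [batch, heads] slope tensor and adding it to pre-softmax scores.
# compute_alibi_tensor_ref hard-codes device="cuda", so the sketch is guarded on
# CUDA availability; the helper name _example_alibi_usage is hypothetical.
def _example_alibi_usage():
    if not torch.cuda.is_available():
        return None
    batch, heads, seqlen = 2, 4, 8
    alibi_slopes = torch.rand(batch, heads, device="cuda")

    bias = compute_alibi_tensor_ref(alibi_slopes, seqlen, seqlen)
    assert bias.shape == (batch, heads, seqlen, seqlen)

    # The bias broadcasts against per-head attention scores before the softmax
    scores = torch.randn(batch, heads, seqlen, seqlen, device="cuda")
    biased_scores = scores + bias
    return biased_scores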