bwd_ref.py

import torch
import math

from .utils import DEBUG


def attention_backward_core_ref_impl(
    do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2
):
    if DEBUG:
        print()
        print("attention_backward_core_ref_impl")
        print("do:", do, do.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("sm_scale:", sm_scale)
        print("causal:", causal)
        print("use_exp2:", use_exp2)

    # remember the input dtypes so the gradients can be cast back at the end
    q_dtype, k_dtype, v_dtype = q.dtype, k.dtype, v.dtype

    # cast everything to float32 for the reference computation
    do = do.to(torch.float32)
    q = q.to(torch.float32)
    k = k.to(torch.float32)
    v = v.to(torch.float32)
    o = o.to(torch.float32)
    softmax_lse = softmax_lse.to(torch.float32)

    # recompute attention_scores in float32 so it matches the forward impl
    attention_scores = torch.matmul(q, k.transpose(-2, -1))
    if DEBUG:
        print("attention_scores:", attention_scores, attention_scores.shape)

    # scale scores
    attention_scaled_scores = sm_scale * attention_scores
    if DEBUG:
        print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape)

    # apply the causal mask if necessary
    if causal:
        L_q, L_k = q.shape[1], k.shape[1]
        row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
        col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
        col_offset = L_q - L_k
        causal_mask = row_idx >= (col_offset + col_idx)
        if DEBUG:
            print("causal_mask:", causal_mask)
        # set -inf where the causal mask is false
        attention_scaled_scores = attention_scaled_scores.masked_fill(
            torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
        )
        if DEBUG:
            print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape)

    # recompute the probabilities using softmax_lse
    if use_exp2:
        RCP_LN = 1 / math.log(2)
        attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN
        softmax_lse_base2 = softmax_lse * RCP_LN
        softmax_lse_3d = softmax_lse_base2.unsqueeze(-1)
        p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d)
    else:
        softmax_lse_3d = softmax_lse.unsqueeze(-1)
        p = torch.exp(attention_scaled_scores - softmax_lse_3d)
    if DEBUG:
        print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape)
        print("p:", p, p.shape)

    # compute gradient wrt v
    dv = torch.matmul(p.transpose(-2, -1), do)
    if DEBUG:
        print("dv:", dv, dv.shape)

    # compute dp
    dp = torch.matmul(do, v.transpose(-2, -1))
    if DEBUG:
        print("dp:", dp, dp.shape)

    # compute delta, then ds from dp; the OAI kernel uses sum(o * do), the
    # mathematically equivalent alternative is sum(p * dp)
    delta = torch.sum(o * do, dim=-1)
    delta_3d = delta.unsqueeze(-1)
    if DEBUG:
        print("delta_3d:", delta_3d, delta_3d.shape)

    ds = (p * (dp - delta_3d)) * sm_scale
    if DEBUG:
        print("ds:", ds, ds.shape)

    # compute gradient wrt k
    dk = torch.matmul(ds.transpose(-2, -1), q)
    if DEBUG:
        print("dk:", dk, dk.shape)

    # compute gradient wrt q
    dq = torch.matmul(ds, k)
    if DEBUG:
        print("dq:", dq, dq.shape)

    # cast the gradients back to the original input dtypes
    dq = dq.to(q_dtype)
    dk = dk.to(k_dtype)
    dv = dv.to(v_dtype)

    # remove the trailing dim of size 1
    delta = delta_3d.squeeze(-1)

    if DEBUG:
        print("attention_backward_core_ref_impl output")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("delta:", delta, delta.shape)

    return dq, dk, dv, delta
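

# Sketch (added for illustration, not part of the original reference): how a matching
# forward pass would typically produce the softmax_lse consumed above. The helper name
# is hypothetical; it assumes q and k shaped [num_heads, L, head_dim] as in the core
# function, and it ignores the causal mask for brevity.
def _softmax_lse_sketch(q, k, sm_scale):
    scores = sm_scale * torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
    # natural-log logsumexp of the scaled scores; the core backward converts to base 2
    # itself when use_exp2 is set, so the same lse serves both paths
    return torch.logsumexp(scores, dim=-1)  # [num_heads, L_q]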


def attention_varlen_backward_pytorch_ref_impl(
    do,
    q,
    k,
    v,
    o,
    softmax_lse,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2,
):
    # Ensure the layout is 'thd'
    if layout != 'thd':
        raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")

    batch_size = cu_seqlens_q.shape[0] - 1
    num_heads = q.shape[1]
    head_dim = q.shape[2]

    # Pre-allocate outputs
    total_L_q = q.shape[0]
    total_L_k = k.shape[0]

    dq = torch.zeros_like(q)
    dk = torch.zeros_like(k)
    dv = torch.zeros_like(v)
    # delta has the same shape as softmax_lse: [total_L_q, num_heads]
    delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device)

    for i in range(batch_size):
        # Get the start and end indices for the current sequence
        start_q = cu_seqlens_q[i].item()
        end_q = cu_seqlens_q[i + 1].item()
        start_k = cu_seqlens_k[i].item()
        end_k = cu_seqlens_k[i + 1].item()

        # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i
        q_i = q[start_q:end_q, :, :]    # [L_q_i, num_heads, head_dim]
        k_i = k[start_k:end_k, :, :]    # [L_k_i, num_heads, head_dim]
        v_i = v[start_k:end_k, :, :]    # [L_k_i, num_heads, head_dim]
        do_i = do[start_q:end_q, :, :]  # [L_q_i, num_heads, head_dim]
        o_i = o[start_q:end_q, :, :]    # [L_q_i, num_heads, head_dim]

        # softmax_lse has shape [total_L_q, num_heads]
        softmax_lse_i = softmax_lse[start_q:end_q, :]  # [L_q_i, num_heads]
        softmax_lse_i = softmax_lse_i.transpose(0, 1)  # [num_heads, L_q_i]

        # Permute to [num_heads, L_q_i, head_dim]
        q_i = q_i.permute(1, 0, 2)
        k_i = k_i.permute(1, 0, 2)
        v_i = v_i.permute(1, 0, 2)
        do_i = do_i.permute(1, 0, 2)
        o_i = o_i.permute(1, 0, 2)
        # softmax_lse_i is already in [num_heads, L_q_i]

        # Call the core backward function for this sequence
        dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl(
            do_i,
            q_i,
            k_i,
            v_i,
            o_i,
            softmax_lse_i,
            sm_scale,
            causal,
            use_exp2
        )

        # Convert back to 'thd' layout
        dq_i = dq_i.permute(1, 0, 2)  # [L_q_i, num_heads, head_dim]
        dk_i = dk_i.permute(1, 0, 2)  # [L_k_i, num_heads, head_dim]
        dv_i = dv_i.permute(1, 0, 2)  # [L_k_i, num_heads, head_dim]

        # Place outputs in the pre-allocated tensors
        dq[start_q:end_q, :, :] = dq_i
        dk[start_k:end_k, :, :] += dk_i  # Accumulate gradients for shared keys
        dv[start_k:end_k, :, :] += dv_i  # Accumulate gradients for shared values

        # delta_i has shape [num_heads, L_q_i]
        delta_i = delta_i.transpose(1, 0)  # [L_q_i, num_heads]
        delta[start_q:end_q, :] = delta_i

    return dq, dk, dv, delta
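

# Sketch (added for illustration): how the cu_seqlens_* tensors used above are typically
# built for the packed 'thd' layout. The helper name and the int32 dtype are assumptions.
def _cu_seqlens_sketch(seqlens):
    # e.g. seqlens = torch.tensor([3, 5]) -> tensor([0, 3, 8], dtype=torch.int32);
    # sequence i then spans rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors
    return torch.cat([torch.zeros(1, dtype=torch.int32, device=seqlens.device),
                      torch.cumsum(seqlens, dim=0).to(torch.int32)])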


def attention_vanilla_backward_pytorch_ref_impl(
    do,
    q,
    k,
    v,
    o,
    softmax_lse,
    sm_scale,
    causal,
    layout,
    use_exp2,
):
    # "bshd" means [batch, seq_len, num_heads, head_dim];
    # "bhsd" means [batch, num_heads, seq_len, head_dim]
    if layout == "bshd":
        if DEBUG:
            print()
            print("Changing layout to bhsd!")
        do = do.transpose(1, 2).contiguous()
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        o = o.transpose(1, 2).contiguous()
    elif layout == "bhsd":
        pass
    else:
        raise ValueError(f"Unknown layout {layout}")

    # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
    batch_size, num_heads, seq_len_q, head_dim = q.shape
    seq_len_k = k.shape[2]

    # Merge batch and heads dimensions
    do = do.reshape(batch_size * num_heads, seq_len_q, head_dim)
    q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
    k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
    v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
    softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q)
    o = o.reshape(batch_size * num_heads, seq_len_q, head_dim)

    dq, dk, dv, delta = attention_backward_core_ref_impl(
        do,
        q,
        k,
        v,
        o,
        softmax_lse,
        sm_scale,
        causal,
        use_exp2
    )

    # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
    dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim)
    dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim)
    dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim)
    delta = delta.reshape(batch_size, num_heads, seq_len_q)

    # Go back to the original layout
    if layout == "bshd":
        if DEBUG:
            print()
            print("Changing back to bshd!")
        dq = dq.transpose(1, 2)
        dk = dk.transpose(1, 2)
        dv = dv.transpose(1, 2)
    elif layout == "bhsd":
        pass
    else:
        raise ValueError(f"Unknown layout {layout}")

    return dq, dk, dv, delta


def attention_backward_pytorch_ref_impl(
    do,
    q,
    k,
    v,
    o,
    softmax_lse,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2
):
    if layout == "thd":
        dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl(
            do,
            q,
            k,
            v,
            o,
            softmax_lse,
            sm_scale,
            causal,
            layout,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            use_exp2,
        )
    else:
        dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl(
            do,
            q,
            k,
            v,
            o,
            softmax_lse,
            sm_scale,
            causal,
            layout,
            use_exp2,
        )

    return dq, dk, dv, delta
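

# Minimal self-check sketch (added for illustration, not part of the original reference).
# It assumes the "bshd" layout with float16 inputs, recomputes the forward pass with plain
# PyTorch ops to obtain o and softmax_lse, and compares the reference backward against
# torch.autograd. Shapes and the non-causal setting are illustrative choices. Run it as a
# module (e.g. python -m <package>.bwd_ref) so the relative import above resolves.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, heads, seq_len, head_dim = 2, 3, 8, 16
    sm_scale = 1.0 / math.sqrt(head_dim)

    q = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16, requires_grad=True)
    k = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16, requires_grad=True)
    v = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16, requires_grad=True)
    do = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16)

    # forward pass in float32 ("bhsd" internally) to produce o and softmax_lse
    qf = q.float().transpose(1, 2)
    kf = k.float().transpose(1, 2)
    vf = v.float().transpose(1, 2)
    scores = sm_scale * torch.matmul(qf, kf.transpose(-2, -1))
    softmax_lse = torch.logsumexp(scores, dim=-1)         # [batch, heads, seq_len]
    o = torch.matmul(torch.softmax(scores, dim=-1), vf)   # [batch, heads, seq_len, head_dim]

    # autograd gradients for comparison
    o.backward(do.float().transpose(1, 2))

    dq_ref, dk_ref, dv_ref, _ = attention_backward_pytorch_ref_impl(
        do, q.detach(), k.detach(), v.detach(),
        o.detach().transpose(1, 2),                       # back to "bshd"
        softmax_lse.detach(),
        sm_scale, causal=False, layout="bshd",
        cu_seqlens_q=None, cu_seqlens_k=None,
        max_seqlen_q=None, max_seqlen_k=None,
        use_exp2=False,
    )
    print("dq max err:", (dq_ref.float() - q.grad.float()).abs().max().item())
    print("dk max err:", (dk_ref.float() - k.grad.float()).abs().max().item())
    print("dv max err:", (dv_ref.float() - v.grad.float()).abs().max().item())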