# interface_torch.py
  1. import torch
  2. from .fwd_prefill import attention_prefill_forward_triton_impl
  3. from .bwd_prefill import attention_prefill_backward_triton_impl
  4. from .fwd_decode import attention_decode_forward_triton_impl
  5. class _attention_prefill(torch.autograd.Function):
  6. @staticmethod
  7. def forward(ctx, q, k, v, o, metadata):
  8. (output,
  9. softmax_lse,
  10. exp_scores,
  11. grid,
  12. head_size,
  13. philox_seed,
  14. philox_offset,
  15. _,
  16. _) = attention_prefill_forward_triton_impl(
  17. q,
  18. k,
  19. v,
  20. o,
  21. metadata.sm_scale,
  22. metadata.alibi_slopes,
  23. metadata.causal,
  24. metadata.bias,
  25. metadata.dropout_p,
  26. metadata.layout,
  27. metadata.cu_seqlens_q,
  28. metadata.cu_seqlens_k,
  29. metadata.max_seqlens_q,
  30. metadata.max_seqlens_k,
  31. metadata.return_scores,
  32. metadata.use_exp2)
  33. ctx.save_for_backward(q, k, v, o, softmax_lse)
  34. ctx.grid = grid
  35. ctx.sm_scale = metadata.sm_scale
  36. ctx.head_size = head_size
  37. ctx.causal = metadata.causal
  38. ctx.alibi_slopes = metadata.alibi_slopes
  39. ctx.dropout_p = metadata.dropout_p
  40. ctx.philox_seed = philox_seed
  41. ctx.philox_offset = philox_offset
  42. ctx.exp_scores = exp_scores
  43. ctx.return_scores = metadata.return_scores
  44. ctx.layout = metadata.layout
  45. ctx.use_exp2 = metadata.use_exp2
  46. return output, softmax_lse, exp_scores
  47. @staticmethod
  48. def backward(ctx, do, *args):
  49. q, k, v, o, softmax_lse = ctx.saved_tensors
  50. return attention_prefill_backward_triton_impl(
  51. do,
  52. q,
  53. k,
  54. v,
  55. o,
  56. softmax_lse,
  57. None,
  58. None,
  59. None,
  60. ctx.sm_scale,
  61. ctx.alibi_slopes,
  62. ctx.causal,
  63. ctx.layout,
  64. None,
  65. None,
  66. None,
  67. None,
  68. ctx.use_exp2
  69. )
# Public functional entry point for the prefill autograd path.
attention_prefill = _attention_prefill.apply
  71. class _attention_decode(torch.autograd.Function):
  72. @staticmethod
  73. def forward(ctx, q, k, v, metadata):
  74. output, softmax_lse = attention_decode_forward_triton_impl(
  75. q,
  76. k,
  77. v,
  78. metadata.sm_scale,
  79. metadata.causal,
  80. metadata.alibi_slopes,
  81. metadata.layout,
  82. metadata.cache_seqlens,
  83. metadata.cache_batch_idx,
  84. metadata.new_kv,
  85. metadata.k_new,
  86. metadata.v_new,
  87. )
  88. return output, softmax_lse
# Public functional entry point for the decode (inference/KV-cache) path.
attention_decode = _attention_decode.apply