flash_bwd_hdim96_fp16_softcap_sm80.cu 741 B

123456789101112131415161718
  1. // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  2. // Splitting the different template instantiations to different files to speed up compilation.
  3. // This file is auto-generated. See "generate_kernels.py"
  4. #include "flash_bwd_launch_template.h"
  5. #ifndef FLASHATTENTION_DISABLE_SM8x
  6. #ifndef FLASHATTENTION_DISABLE_HDIM96
  // Explicit specialization of run_mha_bwd_ for the template arguments
  // <80, cutlass::half_t, 96, true>. It does no work of its own: it hands
  // `params` and `stream` straight to run_mha_bwd_hdim96.
  // NOTE(review): judging by the file name and the DISABLE_SM8x / DISABLE_HDIM96
  // guards, 80 is presumably the target SM architecture, 96 the head dimension
  // and `true` the softcap switch -- confirm against flash_bwd_launch_template.h.
  template<>
  void run_mha_bwd_<80, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
      // Name the non-type template arguments so the forwarding call reads clearly.
      constexpr int kArch = 80;
      constexpr bool kSoftcap = true;
      run_mha_bwd_hdim96<kArch, cutlass::half_t, kSoftcap>(params, stream);
  }
  // Explicit specialization of run_mha_bwd_ for the template arguments
  // <86, cutlass::half_t, 96, true>. It does no work of its own: it hands
  // `params` and `stream` straight to run_mha_bwd_hdim96.
  // NOTE(review): judging by the file name and the DISABLE_SM8x / DISABLE_HDIM96
  // guards, 86 is presumably the target SM architecture, 96 the head dimension
  // and `true` the softcap switch -- confirm against flash_bwd_launch_template.h.
  template<>
  void run_mha_bwd_<86, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
      // Name the non-type template arguments so the forwarding call reads clearly.
      constexpr int kArch = 86;
      constexpr bool kSoftcap = true;
      run_mha_bwd_hdim96<kArch, cutlass::half_t, kSoftcap>(params, stream);
  }
  15. #endif
  16. #endif