// Explicit instantiations of run_mha_fwd_combine_.
//
// The function template itself is expected to come from
// "flash_fwd_combine_launch_template.h"; this translation unit only forces the
// compiler to emit code for the specific template-argument combinations listed
// below, so that other translation units can link against them.
//
// NOTE(review): the meaning of the three template arguments is not visible from
// this file. They look like <element type, second element type (always float
// here), integer size in {64, 128, 256}> -- confirm against the header before
// adding or removing combinations.

#include "flash_fwd_combine_launch_template.h"

// First template argument: float.
template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<float, float, 256>(Flash_fwd_params &params, cudaStream_t stream);

// First template argument: cutlass::half_t.
template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<cutlass::half_t, float, 256>(Flash_fwd_params &params, cudaStream_t stream);

// First template argument: cutlass::bfloat16_t.
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 256>(Flash_fwd_params &params, cudaStream_t stream);