123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include <cuda.h>
- #include <vector>
- ////////////////////////////////////////////////////////////////////////////////////////////////////
// Common description of the Q, K and V input tensors: type-erased base
// pointers, their batch/row/head strides, and the head counts.
// Strides are stored as 64-bit index_t values.
struct Qkv_params {
    using index_t = int64_t;

    // Base addresses of Q, K and V (element type erased to void).
    void *__restrict__ q_ptr, *__restrict__ k_ptr, *__restrict__ v_ptr;

    // Strides between consecutive batches, rows and heads of Q, K and V.
    index_t q_batch_stride, k_batch_stride, v_batch_stride;
    index_t q_row_stride, k_row_stride, v_row_stride;
    index_t q_head_stride, k_head_stride, v_head_stride;
    // Stride along the last (head) dimension of V.
    // NOTE(review): presumably allows a non-unit innermost stride for V — confirm with callers.
    index_t v_dim_stride;

    // Head counts: h for Q and h_k for K/V.
    // NOTE(review): meaning inferred from the names — confirm h_k is the K/V head count.
    int h, h_k;
};
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- struct Flash_fwd_params : public Qkv_params {
- using index_t = int64_t;
- // The O matrix (output).
- void * __restrict__ o_ptr;
- void * __restrict__ oaccum_ptr;
- // The stride between rows of O.
- index_t o_batch_stride;
- index_t o_row_stride;
- index_t o_head_stride;
- // The pointer to the softmax sum.
- void * __restrict__ softmax_lse_ptr;
- void * __restrict__ softmax_lseaccum_ptr;
- // For FP8 scaling
- float * __restrict__ q_descale_ptr;
- float * __restrict__ k_descale_ptr;
- float * __restrict__ v_descale_ptr;
- index_t q_descale_batch_stride;
- index_t q_descale_head_stride;
- index_t k_descale_batch_stride;
- index_t k_descale_head_stride;
- index_t v_descale_batch_stride;
- index_t v_descale_head_stride;
- // The dimensions.
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
- int total_q, total_k, total_knew;
- int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
- // The scaling factors for the kernel.
- float scale_softmax;
- float softcap;
- // array of length b+1 holding starting offset of each sequence.
- int * __restrict__ cu_seqlens_q;
- int * __restrict__ cu_seqlens_k;
- int * __restrict__ cu_seqlens_knew;
- int * __restrict__ leftpad_k;
- // If provided, the actual length of each q/k sequence.
- int *__restrict__ seqused_q;
- int *__restrict__ seqused_k;
- // The stride between rows of Oaccum.
- index_t oaccum_split_stride;
- index_t oaccum_batch_stride;
- index_t oaccum_row_stride;
- index_t oaccum_head_stride;
- // The stride between rows of LSEaccum.
- index_t lseaccum_split_stride;
- index_t lseaccum_batch_stride;
- index_t lseaccum_head_stride;
- // The K_new and V_new matrices.
- void * __restrict__ knew_ptr;
- void * __restrict__ vnew_ptr;
- // The stride between rows of the Q, K and V matrices.
- index_t knew_batch_stride;
- index_t vnew_batch_stride;
- index_t knew_row_stride;
- index_t vnew_row_stride;
- index_t knew_head_stride;
- index_t vnew_head_stride;
- // The cos and sin matrices for rotary embedding.
- void * __restrict__ rotary_cos_ptr;
- void * __restrict__ rotary_sin_ptr;
- // The indices to index into the KV cache.
- int * __restrict__ kv_batch_idx;
- // Paged KV cache
- int * __restrict__ page_table;
- index_t page_table_batch_stride;
- int page_size;
- int num_pages;
- // The dropout probability (probability of keeping an activation).
- float p_dropout;
- // uint32_t p_dropout_in_uint;
- // uint16_t p_dropout_in_uint16_t;
- uint8_t p_dropout_in_uint8_t;
- // Scale factor of 1 / (1 - p_dropout).
- float rp_dropout;
- // Local window size
- int window_size_left, window_size_right;
- int sink_token_length;
- // Pointer to the RNG seed (idx 0) and offset (idx 1).
- uint64_t * rng_state;
- bool is_bf16;
- bool is_fp32;
- bool is_e4m3;
- bool is_causal;
- bool is_local;
- bool is_rotary_interleaved;
- int num_splits; // For split-KV version
- bool pack_gqa;
- int * __restrict__ tile_count_semaphore;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- struct Flash_bwd_params : public Flash_fwd_params {
- using index_t = int64_t;
- // The dO and dQKV matrices.
- void *__restrict__ do_ptr;
- void *__restrict__ dq_ptr;
- void *__restrict__ dk_ptr;
- void *__restrict__ dv_ptr;
- // To accumulate dQ
- void *__restrict__ dq_accum_ptr;
- void *__restrict__ dk_accum_ptr;
- void *__restrict__ dv_accum_ptr;
- // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
- // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
- // dv_accum_ptr;
- // The stride between rows of the dO, dQ, dK and dV matrices.
- index_t do_batch_stride;
- index_t do_row_stride;
- index_t do_head_stride;
- index_t dq_batch_stride;
- index_t dk_batch_stride;
- index_t dv_batch_stride;
- index_t dq_row_stride;
- index_t dk_row_stride;
- index_t dv_row_stride;
- index_t dq_head_stride;
- index_t dk_head_stride;
- index_t dv_head_stride;
- // The pointer to the softmax d sum.
- void *__restrict__ dsoftmax_sum;
- void *__restrict__ softmax_lse_log2_ptr;
- int *__restrict__ dq_semaphore;
- int *__restrict__ dk_semaphore;
- int *__restrict__ dv_semaphore;
- bool deterministic;
- index_t dq_accum_split_stride;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<typename T, int Headdim, bool Split, bool PagedKV> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
- template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
- template<typename T, typename Tpartial, int Headdim> void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream);
|