/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
#include <cstdint>
#include <limits>
#include <mutex>
#include <string>
#include <tuple>
#include <vector>

#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
- fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
- std::string dtype,
- int head_size,
- bool has_dropout,
- bool has_lse,
- bool enable_alibi)
- {
- return fmha_fwd_traits{head_size,
- head_size,
- dtype,
- true, // is_group_mode
- true, // is_v_rowmajor
- mask.type,
- enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
- has_lse,
- has_dropout,
- false}; // do_fp8_static_quant
- }
- fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
- bool has_dropout_randval,
- const mask_info &mask,
- // sizes
- const int b,
- const int max_seqlen_q,
- const int h,
- const int h_k,
- const int d,
- // device pointers
- const at::Tensor q,
- const at::Tensor k,
- const at::Tensor v,
- const at::Tensor seqlens_q,
- const at::Tensor seqlens_k,
- c10::optional<at::Tensor> &alibi_slopes_,
- at::Tensor out,
- at::Tensor softmax_lse,
- at::Tensor dropout_randval,
- float softmax_scale,
- float p_dropout,
- uint64_t drop_seed,
- uint64_t drop_offset)
- {
- // q: (total_q, nheads, d)
- // k: (total_k, nheads_k, d)
- // v: (total_k, nheads_k, d)
- // o: (total_q, nheads, d)
- // alibi_slopes:(batch, nheads) or (nhead)
- // lse: (nheads, total_q)
- // randval: (nheads, total_q, max_seqlen_k)
- ck_tile::index_t total_q = q.size(0);
- ck_tile::index_t total_k = k.size(0);
- ck_tile::index_t stride_q = q.stride(0);
- ck_tile::index_t stride_k = k.stride(0);
- ck_tile::index_t stride_v = v.stride(0);
- ck_tile::index_t stride_o = out.stride(0);
- ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
- ck_tile::index_t nhead_stride_q = q.stride(1);
- ck_tile::index_t nhead_stride_k = k.stride(1);
- ck_tile::index_t nhead_stride_v = v.stride(1);
- ck_tile::index_t nhead_stride_o = out.stride(1);
- ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
- ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
- ck_tile::index_t batch_stride_q = 0;
- ck_tile::index_t batch_stride_k = 0;
- ck_tile::index_t batch_stride_v = 0;
- ck_tile::index_t batch_stride_o = 0;
- ck_tile::index_t batch_stride_lse = 0;
- ck_tile::index_t batch_stride_randval = 0;
- void *alibi_slopes_ptr = nullptr;
- ck_tile::index_t stride_alibi_slopes = 0;
- if (alibi_slopes_.has_value()) {
- auto alibi_slopes = alibi_slopes_.value();
- CHECK_DEVICE(alibi_slopes);
- TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
- TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
- alibi_slopes_ptr = alibi_slopes.data_ptr();
- stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
- }
- return fmha_fwd_args{q.data_ptr(),
- k.data_ptr(),
- v.data_ptr(),
- alibi_slopes_ptr, // bias
- has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
- has_lse ? softmax_lse.data_ptr() : nullptr,
- out.data_ptr(),
- seqlens_q.data_ptr(), // seqstart_q
- seqlens_k.data_ptr(), // seqstart_k
- nullptr, // seqlen_kpads
- total_q,
- total_k,
- b,
- max_seqlen_q,
- d, // hdim_q
- d, // hdim_v
- h, // nhead
- h_k, // nhead_k
- softmax_scale, // scale_s
- 1, // scale_p
- 1, // scale_o
- stride_q,
- stride_k,
- stride_v,
- stride_alibi_slopes,
- stride_randval,
- stride_o,
- nhead_stride_q,
- nhead_stride_k,
- nhead_stride_v,
- 0, // nhead_stride_bias, FA without bias
- nhead_stride_randval,
- nhead_stride_lse,
- nhead_stride_o,
- batch_stride_q,
- batch_stride_k,
- batch_stride_v,
- 0, // batch_stride_bias, FA without bias
- batch_stride_randval,
- batch_stride_lse,
- batch_stride_o,
- mask.left,
- mask.right,
- static_cast<ck_tile::index_t>(mask.type),
- p_dropout,
- has_dropout_randval,
- {drop_seed, drop_offset}};
- }
- std::vector<at::Tensor>
- mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
- const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
- const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
- c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
- const at::Tensor &cu_seqlens_q, // b+1
- const at::Tensor &cu_seqlens_k, // b+1
- c10::optional<at::Tensor> & /*seqused_k*/,
- c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
- c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
- c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
- int max_seqlen_q,
- const int max_seqlen_k,
- const float p_dropout,
- const float softmax_scale,
- const bool zero_tensors,
- bool is_causal,
- int window_size_left,
- int window_size_right,
- const float /*softcap*/,
- const bool return_dropout_randval,
- c10::optional<at::Generator> gen_)
- {
- auto q_dtype = q.dtype();
- TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
- "FlashAttention only support fp16 and bf16 data type");
- TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
- std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
- CHECK_DEVICE(cu_seqlens_q);
- CHECK_DEVICE(cu_seqlens_k);
- // TODO - Support paged_KV
- const bool paged_KV = block_table_.has_value();
- TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- CHECK_CONTIGUOUS(cu_seqlens_q);
- CHECK_CONTIGUOUS(cu_seqlens_k);
- const auto sizes = q.sizes();
- const int batch_size = cu_seqlens_q.numel() - 1;
- int num_heads = sizes[1];
- const int head_size = sizes[2];
- const int num_heads_k = k.size(1);
- const int max_num_blocks_per_seq = 0;
- const int num_blocks = 0;
- if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
- // TODO
- // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
- // H/t Daniel Haziza
- const int total_q = q.size(0);
- const int total_k = k.size(0);
- TORCH_CHECK(batch_size > 0, "batch size must be postive");
- TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
- TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
- if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
- if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
- mask_info mask;
- if (is_causal) {
- // Causal is the special case where window_size_right == 0 and window_size_left < 0.
- window_size_right = 0;
- std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
- mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
- }
- else if (window_size_left == -1 && window_size_right == -1) {
- mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
- }
- else {
- // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
- std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
- mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
- }
- CHECK_SHAPE(q, total_q, num_heads, head_size);
- CHECK_SHAPE(k, total_k, num_heads_k, head_size);
- CHECK_SHAPE(v, total_k, num_heads_k, head_size);
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
- at::Tensor out;
- if (out_.has_value()) {
- out = out_.value();
- TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
- CHECK_DEVICE(out);
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
- CHECK_SHAPE(out, total_q, num_heads, head_size);
- }
- else {
- out = torch::empty_like(q);
- }
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
- auto opts = q.options();
- bool has_lse = true;
- bool has_dropout = p_dropout > 0.0f;
- at::Tensor softmax_lse;
- // TODO - check gradient, only training require lse
- softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32));
- at::Tensor p;
- if (return_dropout_randval) {
- TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
- p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
- }
- else {
- p = torch::empty({ 0 }, opts);
- }
- if (zero_tensors)
- {
- out.zero_();
- softmax_lse.fill_(-std::numeric_limits<float>::infinity());
- if (return_dropout_randval) {p.zero_();}
- }
- uint64_t drop_seed = 1, drop_offset = 0;
- int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
- auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
- auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
- if (p_dropout > 0.0) {
- auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
- gen_, at::cuda::detail::getDefaultCUDAGenerator());
- // See Note [Acquire lock when using random generators]
- std::lock_guard<std::mutex> lock(gen->mutex_);
- auto philox_args = gen->philox_cuda_state(counter_offset);
- std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
- }
- rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
- rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
- if (max_seqlen_k > 0) {
- auto stream = at::cuda::getCurrentHIPStream().stream();
- ck_tile::stream_config stream_config{stream};
- auto traits =
- get_ck_fmha_varlen_fwd_traits(
- mask,
- q_dtype_str,
- head_size,
- has_dropout,
- has_lse,
- alibi_slopes_.has_value());
- auto args =
- get_ck_fmha_varlen_fwd_args(
- has_lse,
- return_dropout_randval,
- mask,
- batch_size,
- max_seqlen_q,
- num_heads,
- num_heads_k,
- head_size,
- q,
- k,
- v,
- cu_seqlens_q,
- cu_seqlens_k,
- alibi_slopes_,
- out,
- softmax_lse,
- p,
- softmax_scale,
- p_dropout,
- drop_seed,
- drop_offset);
- float t = fmha_fwd(traits, args, stream_config);
- TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
- }
- else {
- // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
- out.zero_();
- softmax_lse.fill_(std::numeric_limits<float>::infinity());
- }
- return {out, softmax_lse, p, rng_state};
- }
|