/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

#include "fmha_fwd.hpp"
#include "mask.hpp"

#include <cstdint>
#include <limits>
#include <mutex>
#include <string>
#include <vector>

// Builds the fmha_fwd_traits describing the forward kernel configuration that is
// passed to fmha_fwd() together with the arguments from get_ck_fmha_fwd_args().
// The configuration is always batch mode (is_group_mode = false), row-major V
// (is_v_rowmajor = true) and without fp8 static quantization.
//
// mask:         mask descriptor; only mask.type is read here.
// dtype:        element type string ("fp16" or "bf16" at the call site in mha_fwd).
// head_size:    head dimension, passed for both of the leading head-size fields
//               (presumably hdim_q / hdim_v -- field order of fmha_fwd_traits assumed).
// has_dropout:  forwarded to the traits' dropout flag.
// has_lse:      forwarded to the traits' log-sum-exp flag.
// enable_alibi: selects bias_enum::alibi instead of bias_enum::no_bias.
fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int head_size,
                                       bool has_dropout,
                                       bool has_lse,
                                       bool enable_alibi)
{
    return fmha_fwd_traits{head_size,
                           head_size,
                           dtype,
                           false, // is_group_mode
                           true,  // is_v_rowmajor
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           has_lse,
                           has_dropout,
                           false}; // do_fp8_static_quant
}

// Packs device pointers, sizes, strides, mask and dropout parameters into the
// fmha_fwd_args aggregate for a batch-mode launch.
//
// Dimension order of the tensors (strides are read from the tensors themselves):
//   q:            (batch_size, seqlen_q, nheads, d)
//   k:            (batch_size, seqlen_k, nheads_k, d)
//   v:            (batch_size, seqlen_k, nheads_k, d)
//   out:          (batch_size, seqlen_q, nheads, d)
//   alibi_slopes: (batch_size, nheads) or (nheads), fp32
//   softmax_lse:  (batch_size, nheads, seqlen_q)            -- used only if has_lse
//   randval:      (batch_size, nheads, seqlen_q, seqlen_k)  -- used only if has_dropout_randval
//
// softmax_lse / dropout_randval may be undefined tensors when the matching flag
// is false: their pointers are then passed as nullptr and their strides as 0.
// The bias slots of the argument struct carry the ALiBi slopes (or nothing).
fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                   bool has_dropout_randval,
                                   const mask_info &mask,
                                   // sizes
                                   const int b,
                                   const int seqlen_q,
                                   const int seqlen_k,
                                   const int h,
                                   const int h_k,
                                   const int d,
                                   // device pointers
                                   const at::Tensor q,
                                   const at::Tensor k,
                                   const at::Tensor v,
                                   c10::optional<at::Tensor> &alibi_slopes_,
                                   at::Tensor out,
                                   at::Tensor softmax_lse,
                                   at::Tensor dropout_randval,
                                   float softmax_scale,
                                   float p_dropout,
                                   uint64_t drop_seed,
                                   uint64_t drop_offset)
{
    // Per-row (sequence position) strides.
    ck_tile::index_t stride_q = q.stride(1);
    ck_tile::index_t stride_k = k.stride(1);
    ck_tile::index_t stride_v = v.stride(1);
    ck_tile::index_t stride_o = out.stride(1);
    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;

    // Per-head strides.
    ck_tile::index_t nhead_stride_q = q.stride(2);
    ck_tile::index_t nhead_stride_k = k.stride(2);
    ck_tile::index_t nhead_stride_v = v.stride(2);
    ck_tile::index_t nhead_stride_o = out.stride(2);
    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;

    // Per-batch strides.
    ck_tile::index_t batch_stride_q = q.stride(0);
    ck_tile::index_t batch_stride_k = k.stride(0);
    ck_tile::index_t batch_stride_v = v.stride(0);
    ck_tile::index_t batch_stride_o = out.stride(0);
    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
    ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        // Only a raw pointer reaches the kernel, which is assumed to read fp32 slopes
        // (the CUDA flash-attention path enforces the same); reject other dtypes
        // instead of letting their bytes be silently reinterpreted.
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) ||
                        alibi_slopes.sizes() == torch::IntArrayRef({b, h}),
                    "ALiBi slopes tensor must have shape (nheads) or (batch_size, nheads)");
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        // 1-D slopes are shared by every batch entry, so their batch stride is 0.
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    return fmha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
                         nullptr, // seqstart_q
                         nullptr, // seqstart_k
                         nullptr,
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q,      // max_seqlen_q
                         d,             // hdim_q
                         d,             // hdim_v
                         h,             // nhead
                         h_k,           // nhead_k
                         softmax_scale, // scale_s
                         1,             // scale_p
                         1,             // scale_o
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_randval,
                         stride_o,
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_randval,
                         nhead_stride_lse,
                         nhead_stride_o,
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_randval,
                         batch_stride_lse,
                         batch_stride_o,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
                         {drop_seed, drop_offset}};
}

// Flash-attention forward pass in batch mode, launched through the CK fmha_fwd() API.
//
// Shapes (see the parameter comments):
//   q:             (batch_size, seqlen_q, num_heads, head_size)
//   k, v:          (batch_size, seqlen_k, num_heads_k, head_size); num_heads_k must divide num_heads
//   out_:          optional output buffer with q's shape. It is written either directly
//                  or, when head_size is not a multiple of 8, by a copy after the launch.
//   alibi_slopes_: optional, (num_heads) or (batch_size, num_heads)
//
// Window sizes: -1 means unbounded, and values >= seqlen_k are treated as -1.
// is_causal forces window_size_right = 0; it is dropped when seqlen_q == 1 and no
// ALiBi slopes are given. The softcap argument (unnamed) is not used here.
// return_dropout_randval requires p_dropout > 0.
//
// NOTE: q is taken by non-const reference. When seqlen_q == 1 and num_heads > num_heads_k
// (and the other swap conditions below hold), q is reassigned to a reshaped and
// transposed view of itself.
//
// Returns {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state},
// where rng_state stores {drop_seed, drop_offset} as int64 values.
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float /*softcap*/,
        const bool return_dropout_randval,
        c10::optional<at::Generator> gen_)
{
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");

    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q);
    CHECK_DEVICE(k);
    CHECK_DEVICE(v);

    // The extents below are read by position, so the rank must be validated first.
    TORCH_CHECK(q.dim() == 4, "query must be a 4-D tensor (batch_size, seqlen_q, num_heads, head_size)");

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const int batch_size = q.size(0);
    int seqlen_q = q.size(1);
    int num_heads = q.size(2);
    const int head_size_og = q.size(3);
    // Remember q's original extents as plain ints. seqlen_q / num_heads may be rewritten
    // below and q itself may be reassigned, so a q.sizes() view must not be held past that.
    const int seqlen_q_og = seqlen_q;
    const int num_heads_og = num_heads;

    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
    // Guards the modulo/division by num_heads_k below.
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }

    mask_info mask;
    if (is_causal) {
        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        window_size_right = 0;
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
    }

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const bool seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
                                         window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 &&
                                         !alibi_slopes_.has_value();
    const int ngroups = num_heads / num_heads_k;
    if (seqlenq_ngroups_swapped) {
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
        seqlen_q = ngroups;
        num_heads = num_heads_k;
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

    // The kernel is launched with the head size rounded up to a multiple of 8
    // (head_size_8x below). Inputs are therefore constant-padded on the last dim to match.
    const bool needs_head_pad = head_size_og % 8 != 0;
    at::Tensor q_padded, k_padded, v_padded;
    if (needs_head_pad) {
        const auto pad_opts = torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8});
        q_padded = torch::nn::functional::pad(q, pad_opts);
        k_padded = torch::nn::functional::pad(k, pad_opts);
        v_padded = torch::nn::functional::pad(v, pad_opts);
    }
    else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);

        if (seqlenq_ngroups_swapped) {
            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
        }
        if (needs_head_pad) {
            // The kernel writes into a padded scratch buffer; the unpadded result is
            // copied back into out_ after the launch.
            out = torch::empty_like(q_padded);
        }
    }
    else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_8x = round_multiple(head_size_og, 8);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{static_cast<char>(q.get_device())};

    auto opts = q.options();
    bool has_lse = true;
    bool has_dropout = p_dropout > 0.0f;

    // TODO - check gradient, only training require lse
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));

    at::Tensor p;
    if (return_dropout_randval) {
        TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
        p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
    }

    uint64_t drop_seed = 1, drop_offset = 0;
    // Widen before multiplying so the product cannot overflow a 32-bit int.
    int64_t counter_offset = static_cast<int64_t>(batch_size) * num_heads * ck_tile::get_warp_size();
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
    }

    // Store the raw 64-bit patterns of the seed/offset in the int64 rng_state tensor.
    rng_state[0] = static_cast<int64_t>(drop_seed);
    rng_state[1] = static_cast<int64_t>(drop_offset);

    if (seqlen_k > 0) {
        auto stream = at::cuda::getCurrentHIPStream().stream();
        ck_tile::stream_config stream_config{stream};

        auto traits = get_ck_fmha_fwd_traits(
            mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());

        auto args = get_ck_fmha_fwd_args(
            has_lse,
            return_dropout_randval,
            mask,
            batch_size,
            seqlen_q,
            seqlen_k,
            num_heads,
            num_heads_k,
            head_size_8x,
            q_padded,
            k_padded,
            v_padded,
            alibi_slopes_,
            out,
            softmax_lse,
            p,
            softmax_scale,
            p_dropout,
            drop_seed,
            drop_offset);

        float t = fmha_fwd(traits, args, stream_config);
        TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
    }
    else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    at::Tensor out_padded = out;
    if (needs_head_pad) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    if (seqlenq_ngroups_swapped) {
        // Undo the (b, ngroups, nheads_kv, d) layout so callers see seqlen_q == 1 again.
        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
    }
    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}