12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
- #include <torch/python.h>
- #include <torch/nn/functional.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <cutlass/numeric_types.h>
- #include "flash.h"
- #include "static_switch.h"
// Argument-validation helpers. Each expands to a TORCH_CHECK whose failure message
// names the offending argument expression via stringification (#tensor).
#define CHECK_DEVICE(tensor) TORCH_CHECK((tensor).is_cuda(), #tensor " must be on CUDA")
#define CHECK_SHAPE(tensor, ...) TORCH_CHECK((tensor).sizes() == torch::IntArrayRef({__VA_ARGS__}), #tensor " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(tensor) TORCH_CHECK((tensor).is_contiguous(), #tensor " must be contiguous")
- void set_params_fprop(Flash_fwd_params ¶ms,
- // sizes
- const size_t b,
- const size_t seqlen_q,
- const size_t seqlen_k,
- const size_t seqlen_q_rounded,
- const size_t seqlen_k_rounded,
- const size_t h,
- const size_t h_k,
- const size_t d,
- const size_t d_rounded,
- // device pointers
- const at::Tensor q,
- const at::Tensor k,
- const at::Tensor v,
- at::Tensor out,
- void *cu_seqlens_q_d,
- void *cu_seqlens_k_d,
- void *seqused_q,
- void *seqused_k,
- void *p_d,
- void *softmax_lse_d,
- float p_dropout,
- float softmax_scale,
- int window_size_left,
- int window_size_right,
- bool seqlenq_ngroups_swapped=false,
- bool unpadded_lse=false,
- bool optimize_for_doc_masking=false) {
- // Reset the parameters
- params = {};
- params.is_bf16 = q.dtype() == torch::kBFloat16;
- params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
- // Set the pointers and strides.
- params.q_ptr = q.data_ptr();
- params.k_ptr = k.data_ptr();
- params.v_ptr = v.data_ptr();
- // All stride are in elements, not bytes.
- params.q_row_stride = q.stride(-3);
- params.k_row_stride = k.stride(-3);
- params.v_row_stride = v.stride(-3);
- params.q_head_stride = q.stride(-2);
- params.k_head_stride = k.stride(-2);
- params.v_head_stride = v.stride(-2);
- params.o_ptr = out.data_ptr();
- params.o_row_stride = out.stride(-3);
- params.o_head_stride = out.stride(-2);
- if (cu_seqlens_q_d == nullptr) {
- params.q_batch_stride = q.stride(0);
- params.k_batch_stride = k.stride(0);
- params.v_batch_stride = v.stride(0);
- params.o_batch_stride = out.stride(0);
- if (seqlenq_ngroups_swapped) {
- params.q_batch_stride *= seqlen_q;
- params.o_batch_stride *= seqlen_q;
- }
- }
- params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
- params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
- params.seqused_q = static_cast<int *>(seqused_q);
- params.seqused_k = static_cast<int *>(seqused_k);
- TORCH_CHECK(
- bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
- "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
- );
- // P = softmax(QK^T)
- params.p_ptr = p_d;
- // Softmax sum
- params.softmax_lse_ptr = softmax_lse_d;
- // Set the dimensions.
- params.b = b;
- params.h = h;
- params.h_k = h_k;
- params.h_h_k_ratio = h / h_k;
- params.seqlen_q = seqlen_q;
- params.seqlen_k = seqlen_k;
- params.seqlen_q_rounded = seqlen_q_rounded;
- params.seqlen_k_rounded = seqlen_k_rounded;
- params.d = d;
- params.d_rounded = d_rounded;
- // Set the different scale values.
- params.scale_softmax = softmax_scale;
- params.scale_softmax_log2 = softmax_scale * M_LOG2E;
- __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
- __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
- params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
- // Set this to probability of keeping an element to simplify things.
- params.p_dropout = 1.f - p_dropout;
- // Convert p from float to int so we don't have to convert the random uint to float to compare.
- // [Minor] We want to round down since when we do the comparison we use <= instead of <
- // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
- // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
- params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
- params.rp_dropout = 1.f / params.p_dropout;
- params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
- TORCH_CHECK(p_dropout < 1.f);
- #ifdef FLASHATTENTION_DISABLE_DROPOUT
- TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
- #endif
- // Causal is the special case where window_size_right == 0 and window_size_left < 0.
- // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
- window_size_left = std::min(int(seqlen_k), window_size_left);
- window_size_right = std::min(int(seqlen_k), window_size_right);
- if (window_size_left < 0) { window_size_left = seqlen_k; }
- if (window_size_right < 0) { window_size_right = seqlen_k; }
- params.window_size_left = window_size_left;
- params.window_size_right = window_size_right;
- params.is_causal = window_size_left == seqlen_k && window_size_right == 0;
- if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) {
- params.is_local = true;
- }
- #ifdef FLASHATTENTION_DISABLE_LOCAL
- TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
- "This flash attention build does not support local attention.");
- #endif
- params.is_seqlens_k_cumulative = true;
- #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
- TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
- #endif
- params.unpadded_lse = unpadded_lse;
- params.optimize_for_doc_masking = optimize_for_doc_masking;
- }
- void set_params_dgrad(Flash_bwd_params ¶ms,
- // sizes
- const size_t b,
- const size_t seqlen_q,
- const size_t seqlen_k,
- const size_t seqlen_q_rounded,
- const size_t seqlen_k_rounded,
- const size_t h,
- const size_t h_k,
- const size_t d,
- const size_t d_rounded,
- // device pointers
- const at::Tensor q,
- const at::Tensor k,
- const at::Tensor v,
- const at::Tensor out,
- const at::Tensor dout,
- at::Tensor dq,
- at::Tensor dk,
- at::Tensor dv,
- void *cu_seqlens_q_d,
- void *cu_seqlens_k_d,
- void *seqused_q,
- void *seqused_k,
- void *dq_accum_d,
- void *dk_accum_d,
- void *dv_accum_d,
- void *softmax_lse_d,
- void *dsoftmax_sum_d,
- float p_dropout,
- float softmax_scale,
- int window_size_left,
- int window_size_right,
- bool deterministic,
- bool seqlenq_ngroups_swapped=false,
- bool unpadded_lse=false,
- bool optimize_for_doc_masking=false) {
- set_params_fprop(params,
- b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
- q, k, v, out,
- cu_seqlens_q_d,
- cu_seqlens_k_d,
- seqused_q,
- seqused_k,
- nullptr,
- softmax_lse_d,
- p_dropout,
- softmax_scale,
- window_size_left,
- window_size_right,
- seqlenq_ngroups_swapped,
- unpadded_lse,
- optimize_for_doc_masking);
- // Set the pointers and strides.
- params.do_ptr = dout.data_ptr();
- params.do_row_stride = dout.stride(-3);
- params.do_head_stride = dout.stride(-2);
- params.dq_ptr = dq.data_ptr();
- params.dk_ptr = dk.data_ptr();
- params.dv_ptr = dv.data_ptr();
- params.dq_row_stride = dq.stride(-3);
- params.dk_row_stride = dk.stride(-3);
- params.dv_row_stride = dv.stride(-3);
- params.dq_head_stride = dq.stride(-2);
- params.dk_head_stride = dk.stride(-2);
- params.dv_head_stride = dv.stride(-2);
- if (cu_seqlens_q_d == nullptr) {
- params.do_batch_stride = dout.stride(0);
- params.dq_batch_stride = dq.stride(0);
- params.dk_batch_stride = dk.stride(0);
- params.dv_batch_stride = dv.stride(0);
- }
- params.dq_accum_ptr = dq_accum_d;
- params.dk_accum_ptr = dk_accum_d;
- params.dv_accum_ptr = dv_accum_d;
- // Softmax sum
- params.dsoftmax_sum = dsoftmax_sum_d;
- params.deterministic = deterministic;
- }
- void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
- // HEADDIM_SWITCH(params.d, [&] {
- // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
- // });
- if (!params.is_e4m3) {
- if (params.is_bf16) {
- if (params.d == 64) {
- run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
- } else if (params.d == 128) {
- run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
- } else {
- run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
- }
- } else {
- if (params.d == 64) {
- run_mha_fwd_<cutlass::half_t, 64>(params, stream);
- } else if (params.d == 128) {
- run_mha_fwd_<cutlass::half_t, 128>(params, stream);
- } else {
- run_mha_fwd_<cutlass::half_t, 256>(params, stream);
- }
- }
- } else {
- if (params.d == 64) {
- run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
- } else if (params.d == 128) {
- run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
- } else if (params.d == 256) {
- run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
- }
- }
- }
- std::vector<at::Tensor>
- mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
- c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
- const float softmax_scale,
- c10::optional<at::Tensor> &descale_q_, // 1
- c10::optional<at::Tensor> &descale_k_, // 1
- c10::optional<at::Tensor> &descale_v_, // 1
- bool is_causal,
- int window_size_left,
- int window_size_right) {
- auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
- TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
- auto q_dtype = q.dtype();
- // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
- // "FlashAttention only support fp16 and bf16 data type for now");
- // TODO: will add e4m3 later
- // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
- // "FlashAttention only support fp16 and bf16 data type");
- // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
- TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- const auto sizes = q.sizes();
- const int batch_size = sizes[0];
- int seqlen_q = sizes[1];
- int num_heads = sizes[2];
- const int head_size_og = sizes[3];
- const int seqlen_k = k.size(1);
- const int num_heads_k = k.size(2);
- TORCH_CHECK(batch_size > 0, "batch size must be positive");
- TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
- TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
- CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
- CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
- at::Tensor q_padded, k_padded, v_padded;
- if (head_size_og % 8 != 0) {
- q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- } else {
- q_padded = q;
- k_padded = k;
- v_padded = v;
- }
- at::Tensor out;
- if (out_.has_value()) {
- out = out_.value();
- // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
- TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
- ? (out.dtype() == at::kHalf)
- : (out.dtype() == q_dtype),
- "Output must have the same dtype as input dtype if dtype is "
- "not fp8, or fp16 for fp8 input.");
- CHECK_DEVICE(out);
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
- if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
- } else {
- if (q_dtype == at::ScalarType::Float8_e4m3fn)
- out = torch::empty_like(q_padded, at::kHalf);
- else
- out = torch::empty_like(q_padded);
- }
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
- const int head_size = round_multiple(head_size_og, 8);
- const int head_size_rounded = round_multiple(head_size, 32);
- const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
- const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
- if (is_causal) { window_size_right = 0; }
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
- auto opts = q.options();
- auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
- at::Tensor p;
- Flash_fwd_params params;
- set_params_fprop(params,
- batch_size,
- seqlen_q, seqlen_k,
- seqlen_q_rounded, seqlen_k_rounded,
- num_heads, num_heads_k,
- head_size, head_size_rounded,
- q_padded, k_padded, v_padded, out,
- /*cu_seqlens_q_d=*/nullptr,
- /*cu_seqlens_k_d=*/nullptr,
- /*seqused_q=*/nullptr,
- /*seqused_k=*/nullptr,
- nullptr,
- softmax_lse.data_ptr(),
- /*p_dropout=*/0.f,
- softmax_scale,
- /*window_size_left=*/window_size_left,
- /*window_size_right=*/window_size_right);
- auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
- params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
- if(q_dtype == at::ScalarType::Float8_e4m3fn) {
- at::Tensor descale_q, descale_k, descale_v;
- if (descale_q_.has_value() && descale_k_.has_value() && descale_k_.has_value()) {
- descale_q = descale_q_.value();
- descale_k = descale_k_.value();
- descale_v = descale_v_.value();
- CHECK_DEVICE(descale_q);
- CHECK_DEVICE(descale_k);
- CHECK_DEVICE(descale_v);
- CHECK_SHAPE(descale_q, 1);
- CHECK_SHAPE(descale_k, 1);
- CHECK_SHAPE(descale_v, 1);
- } else {
- descale_q = torch::ones({1}, opts.dtype(at::kFloat));
- descale_k = torch::ones({1}, opts.dtype(at::kFloat));
- descale_v = torch::ones({1}, opts.dtype(at::kFloat));
- }
- params.descale_q_ptr = descale_q.data_ptr<float>();
- params.descale_k_ptr = descale_k.data_ptr<float>();
- params.descale_v_ptr = descale_v.data_ptr<float>();
- } else {
- params.descale_q_ptr = nullptr;
- params.descale_k_ptr = nullptr;
- params.descale_v_ptr = nullptr;
- }
- if (seqlen_k > 0) {
- auto stream = at::cuda::getCurrentCUDAStream().stream();
- run_mha_fwd(params, stream);
- } else {
- // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
- out.zero_();
- softmax_lse.fill_(std::numeric_limits<float>::infinity());
- }
- at::Tensor out_padded = out;
- if (head_size_og % 8 != 0) {
- out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- if (out_.has_value()) { out_.value().copy_(out); }
- }
- return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
- }
- std::vector<at::Tensor>
- mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
- const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
- const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
- c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
- const at::Tensor &cu_seqlens_q, // b+1
- const at::Tensor &cu_seqlens_k, // b+1
- c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
- c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
- int max_seqlen_q,
- const int max_seqlen_k,
- const float softmax_scale,
- bool is_causal,
- int window_size_left,
- int window_size_right,
- bool optimize_for_doc_masking) {
- auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
- TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
- auto q_dtype = q.dtype();
- TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
- "FlashAttention only support fp16 and bf16 data type");
- TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
- CHECK_DEVICE(cu_seqlens_q);
- CHECK_DEVICE(cu_seqlens_k);
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- CHECK_CONTIGUOUS(cu_seqlens_q);
- CHECK_CONTIGUOUS(cu_seqlens_k);
- const auto sizes = q.sizes();
- const int batch_size = cu_seqlens_q.numel() - 1;
- int num_heads = sizes[1];
- const int head_size_og = sizes[2];
- const int num_heads_k = k.size(1);
- void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
- const int total_q = q.sizes()[0];
- TORCH_CHECK(batch_size > 0, "batch size must be positive");
- TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
- CHECK_SHAPE(q, total_q, num_heads, head_size_og);
- const int total_k = k.size(0);
- CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
- CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
- if (seqused_q.has_value()){
- auto seqused_q_ = seqused_q.value();
- TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
- TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
- TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
- CHECK_SHAPE(seqused_q_, batch_size);
- }
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
- if (seqused_k.has_value()){
- auto seqused_k_ = seqused_k.value();
- TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
- TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
- TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
- CHECK_SHAPE(seqused_k_, batch_size);
- }
- at::Tensor q_padded, k_padded, v_padded;
- if (head_size_og % 8 != 0) {
- q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- } else {
- q_padded = q;
- k_padded = k;
- v_padded = v;
- }
- at::Tensor out;
- if (out_.has_value()) {
- out = out_.value();
- TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
- CHECK_DEVICE(out);
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
- CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
- if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
- } else {
- out = torch::empty_like(q_padded);
- }
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
- const int head_size = round_multiple(head_size_og, 8);
- const int head_size_rounded = round_multiple(head_size, 32);
- const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
- const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
- if (is_causal) { window_size_right = 0; }
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
- auto opts = q.options();
- auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
- Flash_fwd_params params;
- set_params_fprop(params,
- batch_size,
- max_seqlen_q, max_seqlen_k,
- seqlen_q_rounded, seqlen_k_rounded,
- num_heads, num_heads_k,
- head_size, head_size_rounded,
- q_padded, k_padded, v_padded, out,
- cu_seqlens_q_d,
- cu_seqlens_k.data_ptr(),
- seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
- seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
- /*p_d=*/nullptr,
- softmax_lse.data_ptr(),
- /*p_dropout=*/0.f,
- softmax_scale,
- window_size_left,
- window_size_right,
- /*seqlenq_ngroups_swapped=*/false,
- /*unpadded_lse=*/true,
- /*optimize_for_doc_masking=*/optimize_for_doc_masking);
- params.total_q = total_q;
- params.total_k = total_k;
- if (max_seqlen_k > 0) {
- auto stream = at::cuda::getCurrentCUDAStream().stream();
- run_mha_fwd(params, stream);
- } else {
- // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
- out.zero_();
- softmax_lse.fill_(std::numeric_limits<float>::infinity());
- }
- at::Tensor out_padded = out;
- if (head_size_og % 8 != 0) {
- out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- if (out_.has_value()) { out_.value().copy_(out); }
- }
- return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
- }
- void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
- // FP16_SWITCH(!params.is_bf16, [&] {
- // HEADDIM_SWITCH(params.d, [&] {
- // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
- // });
- // });
- if (!params.is_bf16) {
- if (params.d <= 64) {
- run_mha_bwd_<cutlass::half_t, 64>(params, stream);
- } else if (params.d <= 96) {
- run_mha_bwd_<cutlass::half_t, 96>(params, stream);
- } else {
- run_mha_bwd_<cutlass::half_t, 128>(params, stream);
- }
- } else {
- if (params.d <= 64) {
- run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
- } else if (params.d <= 96) {
- run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
- } else {
- run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
- }
- }
- }
- std::vector<at::Tensor>
- mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
- const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
- const at::Tensor &softmax_lse, // b x h x seqlen_q
- c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
- c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
- c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
- const float softmax_scale,
- const bool is_causal,
- int window_size_left,
- int window_size_right,
- const bool deterministic) {
- #ifdef FLASHATTENTION_DISABLE_BACKWARD
- TORCH_CHECK(false, "This flash attention build does not support backward.");
- #endif
- auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
- TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
- auto stream = at::cuda::getCurrentCUDAStream().stream();
- auto q_dtype = q.dtype();
- TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
- "FlashAttention only support fp16 and bf16 data type");
- TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
- TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
- CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
- TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
- const auto sizes = q.sizes();
- const int batch_size = sizes[0];
- const int seqlen_q = sizes[1];
- const int num_heads = sizes[2];
- const int head_size_og = dout.size(3);
- const int head_size = sizes[3];
- const int seqlen_k = k.size(1);
- const int num_heads_k = k.size(2);
- TORCH_CHECK(batch_size > 0, "batch size must be positive");
- TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
- TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
- const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
- // This should match the kernel configs
- const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
- const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
- const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
- TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
- CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
- CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
- CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
- at::Tensor dq, dk, dv;
- if (dq_.has_value()) {
- dq = dq_.value();
- TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
- CHECK_DEVICE(dq);
- TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
- CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
- } else {
- dq = torch::empty_like(q);
- }
- if (dk_.has_value()) {
- dk = dk_.value();
- TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
- CHECK_DEVICE(dk);
- TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
- CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
- } else {
- dk = torch::empty_like(k);
- }
- if (dv_.has_value()) {
- dv = dv_.value();
- TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
- CHECK_DEVICE(dv);
- TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
- CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
- } else {
- dv = torch::empty_like(v);
- }
- at::Tensor dout_padded;
- if (head_size_og % 8 != 0) {
- dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- } else {
- dout_padded = dout;
- }
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
- auto opts = q.options();
- // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
- auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
- auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
- at::Tensor dq_accum;
- at::Tensor dk_accum, dv_accum;
- dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
- // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
- // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
- at::Tensor dk_expanded, dv_expanded;
- if (num_heads_k != num_heads) { // MQA / GQA
- dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
- dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
- } else {
- dk_expanded = dk;
- dv_expanded = dv;
- }
- if (is_causal) { window_size_right = 0; }
- Flash_bwd_params params;
- set_params_dgrad(params,
- batch_size,
- seqlen_q, seqlen_k,
- seqlen_q_rounded, seqlen_k_rounded,
- num_heads, num_heads_k,
- head_size, head_size_rounded,
- q, k, v, out,
- dout_padded, dq, dk_expanded, dv_expanded,
- /*cu_seqlens_q_d=*/nullptr,
- /*cu_seqlens_k_d=*/nullptr,
- /*seqused_q=*/nullptr,
- /*seqused_k=*/nullptr,
- dq_accum.data_ptr(),
- // loop ? dk_accum.data_ptr() : nullptr,
- // loop ? dv_accum.data_ptr() : nullptr,
- nullptr,
- nullptr,
- softmax_lse.data_ptr(),
- softmax_d.data_ptr(),
- /*p_dropout=*/0.f,
- softmax_scale,
- /*window_size_left=*/window_size_left,
- /*window_size_right=*/window_size_right,
- deterministic);
- params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
- // Will be zero'ed out in the backward preprocess kernel
- at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
- params.dq_semaphore = dq_semaphore.data_ptr<int>();
- // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
- if (seqlen_q > 0) {
- run_mha_bwd(params, stream);
- } else {
- // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
- dk_expanded.zero_();
- dv_expanded.zero_();
- softmax_d.zero_();
- }
- // For MQA/GQA we need to sum dK and dV across the groups
- if (num_heads_k != num_heads) {
- at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
- at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
- }
- if (head_size_og % 8 != 0) {
- dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- }
- return { dq, dk, dv, softmax_d, dq_accum};
- }
- std::vector<at::Tensor>
- mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
- const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
- const at::Tensor &softmax_lse, // b x h x seqlen_q
- c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
- c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
- c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
- const at::Tensor &cu_seqlens_q, // b+1
- const at::Tensor &cu_seqlens_k, // b+1
- c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
- c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
- const int max_seqlen_q,
- const int max_seqlen_k, // max sequence length to choose the kernel
- const float softmax_scale,
- const bool is_causal,
- int window_size_left,
- int window_size_right,
- const bool deterministic,
- const bool optimize_for_doc_masking) {
- #ifdef FLASHATTENTION_DISABLE_BACKWARD
- TORCH_CHECK(false, "This flash attention build does not support backward.");
- #endif
- auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
- TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
- auto stream = at::cuda::getCurrentCUDAStream().stream();
- auto q_dtype = q.dtype();
- TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
- "FlashAttention only support fp16 and bf16 data type");
- TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
- TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
- CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
- CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
- TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
- CHECK_CONTIGUOUS(cu_seqlens_q);
- CHECK_CONTIGUOUS(cu_seqlens_k);
- const auto sizes = q.sizes();
- const int total_q = sizes[0];
- const int batch_size = cu_seqlens_q.numel() - 1;
- const int num_heads = sizes[1];
- const int head_size_og = dout.size(2);
- const int head_size = sizes[2];
- const int total_k = k.size(0);
- const int num_heads_k = k.size(1);
- TORCH_CHECK(batch_size > 0, "batch size must be positive");
- TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
- TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
- const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
- // This should match the kernel configs
- const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
- const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
- const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
- int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128);
- TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
- CHECK_SHAPE(q, total_q, num_heads, head_size_og);
- CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
- CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
- CHECK_SHAPE(out, total_q, num_heads, head_size);
- CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
- if (seqused_q.has_value()){
- auto seqused_q_ = seqused_q.value();
- TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
- TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
- TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
- CHECK_SHAPE(seqused_q_, batch_size);
- }
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
- if (seqused_k.has_value()){
- auto seqused_k_ = seqused_k.value();
- TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
- TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
- TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
- CHECK_SHAPE(seqused_k_, batch_size);
- }
- at::Tensor dq, dk, dv;
- if (dq_.has_value()) {
- dq = dq_.value();
- TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
- CHECK_DEVICE(dq);
- TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
- CHECK_SHAPE(dq, total_q, num_heads, head_size);
- } else {
- dq = torch::empty_like(q);
- }
- if (dk_.has_value()) {
- dk = dk_.value();
- TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
- CHECK_DEVICE(dk);
- TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
- CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
- } else {
- dk = torch::empty_like(k);
- }
- if (dv_.has_value()) {
- dv = dv_.value();
- TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
- CHECK_DEVICE(dv);
- TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
- CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
- } else {
- dv = torch::empty_like(v);
- }
- at::Tensor dout_padded;
- if (head_size_og % 8 != 0) {
- dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
- } else {
- dout_padded = dout;
- }
- if (is_causal) { window_size_right = 0; }
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
- auto opts = q.options();
- // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
- auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
- auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
- at::Tensor dq_accum;
- at::Tensor dk_accum, dv_accum;
- dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
- // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
- // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
- at::Tensor dk_expanded, dv_expanded;
- if (num_heads_k != num_heads) { // MQA / GQA
- dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
- dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
- } else {
- dk_expanded = dk;
- dv_expanded = dv;
- }
- Flash_bwd_params params;
- set_params_dgrad(params,
- batch_size,
- max_seqlen_q, max_seqlen_k,
- seqlen_q_rounded, seqlen_k_rounded,
- num_heads, num_heads_k,
- head_size, head_size_rounded,
- q, k, v, out,
- dout_padded, dq, dk_expanded, dv_expanded,
- cu_seqlens_q.data_ptr(),
- cu_seqlens_k.data_ptr(),
- seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
- seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
- dq_accum.data_ptr(),
- // loop ? dk_accum.data_ptr() : nullptr,
- // loop ? dv_accum.data_ptr() : nullptr,
- nullptr,
- nullptr,
- softmax_lse.data_ptr(),
- softmax_d.data_ptr(),
- /*p_dropout=*/0.f,
- softmax_scale,
- /*window_size_left=*/window_size_left,
- /*window_size_right=*/window_size_right,
- deterministic,
- /*seqlenq_ngroups_swapped=*/false,
- /*unpadded_lse=*/true,
- optimize_for_doc_masking);
- params.total_q = total_q;
- params.total_k = total_k;
- params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
- // Will be zero'ed out in the backward preprocess kernel
- at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
- params.dq_semaphore = dq_semaphore.data_ptr<int>();
- if (max_seqlen_q > 0) {
- run_mha_bwd(params, stream);
- } else {
- // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
- dk_expanded.zero_();
- dv_expanded.zero_();
- softmax_d.zero_();
- }
- // For MQA/GQA we need to sum dK and dV across the groups
- if (num_heads_k != num_heads) {
- at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
- at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
- }
- if (head_size_og % 8 != 0) {
- dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
- }
- return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
- }
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.doc() = "FlashAttention";
- m.def("fwd", &mha_fwd, "Forward pass");
- m.def("bwd", &mha_bwd, "Backward pass");
- m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
- m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass");
- }
|