|
@@ -13,7 +13,9 @@
|
|
|
#include "flash.h"
|
|
|
#include "static_switch.h"
|
|
|
|
|
|
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
|
|
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
|
|
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
|
|
|
|
|
|
|
|
void set_params_fprop(Flash_fwd_params ¶ms,
|
|
@@ -260,9 +262,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
|
|
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
|
|
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
|
|
|
|
|
- TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
|
|
|
|
|
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
@@ -299,7 +299,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
|
|
if (out_.has_value()) {
|
|
|
out = out_.value();
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
|
|
- TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
|
|
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
|
@@ -426,17 +426,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
|
|
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
|
|
|
|
|
- TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
|
|
- TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
|
|
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
|
|
+ CHECK_DEVICE(cu_seqlens_q);
|
|
|
+ CHECK_DEVICE(cu_seqlens_k);
|
|
|
|
|
|
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
- TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
|
|
- TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
|
|
+ CHECK_CONTIGUOUS(cu_seqlens_q);
|
|
|
+ CHECK_CONTIGUOUS(cu_seqlens_k);
|
|
|
|
|
|
const auto sizes = q.sizes();
|
|
|
|
|
@@ -471,7 +469,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
if (out_.has_value()) {
|
|
|
out = out_.value();
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
|
|
- TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
|
|
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
|
@@ -610,12 +608,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
|
|
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
|
|
|
|
|
- TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
|
|
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
|
|
|
|
|
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
@@ -657,7 +651,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
if (dq_.has_value()) {
|
|
|
dq = dq_.value();
|
|
|
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
|
|
- TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dq);
|
|
|
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
|
|
|
} else {
|
|
@@ -666,7 +660,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
if (dk_.has_value()) {
|
|
|
dk = dk_.value();
|
|
|
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
|
|
- TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dk);
|
|
|
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
|
|
|
} else {
|
|
@@ -675,7 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
if (dv_.has_value()) {
|
|
|
dv = dv_.value();
|
|
|
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
|
|
- TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dv);
|
|
|
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
|
|
|
} else {
|
|
@@ -820,22 +814,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
|
|
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
|
|
|
|
|
- TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
|
|
- TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
|
|
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
|
|
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
|
|
+ CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
|
|
|
|
|
|
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
|
|
- TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
|
|
- TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
|
|
+ CHECK_CONTIGUOUS(cu_seqlens_q);
|
|
|
+ CHECK_CONTIGUOUS(cu_seqlens_k);
|
|
|
|
|
|
const auto sizes = q.sizes();
|
|
|
|
|
@@ -873,7 +862,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
if (dq_.has_value()) {
|
|
|
dq = dq_.value();
|
|
|
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
|
|
- TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dq);
|
|
|
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
|
|
} else {
|
|
@@ -882,7 +871,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
if (dk_.has_value()) {
|
|
|
dk = dk_.value();
|
|
|
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
|
|
- TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dk);
|
|
|
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
|
|
|
} else {
|
|
@@ -891,7 +880,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
if (dv_.has_value()) {
|
|
|
dv = dv_.value();
|
|
|
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
|
|
- TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
|
|
+ CHECK_DEVICE(dv);
|
|
|
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
|
|
|
} else {
|
|
@@ -1000,9 +989,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
|
|
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
|
|
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
|
|
|
+ c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
|
|
+ c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
|
|
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
|
|
const float softmax_scale,
|
|
|
bool is_causal,
|
|
|
+ bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
|
|
int num_splits
|
|
|
) {
|
|
|
|
|
@@ -1023,9 +1015,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
|
|
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
|
|
|
|
|
|
- TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
|
|
|
|
|
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
|
@@ -1071,7 +1061,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
if (out_.has_value()) {
|
|
|
out = out_.value();
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
|
|
- TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
|
|
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
|
@@ -1118,8 +1108,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
v = v_.value();
|
|
|
TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
|
|
|
TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
|
|
|
- TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
|
|
|
- TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
|
|
|
+ CHECK_DEVICE(k); CHECK_DEVICE(v);
|
|
|
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
|
|
|
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
|
|
|
int seqlen_knew = k.size(1);
|
|
@@ -1147,13 +1136,40 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
if (seqlens_k_.has_value()) {
|
|
|
auto seqlens_k = seqlens_k_.value();
|
|
|
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
|
|
- TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
|
|
|
- TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
|
|
|
+ CHECK_DEVICE(seqlens_k);
|
|
|
+ CHECK_CONTIGUOUS(seqlens_k);
|
|
|
CHECK_SHAPE(seqlens_k, batch_size);
|
|
|
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
|
|
|
}
|
|
|
params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
|
|
|
|
|
|
+ if (rotary_cos_.has_value()) {
|
|
|
+ TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
|
|
|
+ auto rotary_cos = rotary_cos_.value();
|
|
|
+ CHECK_DEVICE(rotary_cos);
|
|
|
+ params.rotary_dim = rotary_cos.size(1) * 2;
|
|
|
+ TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
|
|
|
+ TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
|
|
|
+ const int seqlen_ro = rotary_cos.size(0);
|
|
|
+ TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
|
|
|
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
|
|
|
+ CHECK_CONTIGUOUS(rotary_cos);
|
|
|
+ TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
|
|
|
+
|
|
|
+ TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
|
|
|
+ auto rotary_sin = rotary_sin_.value();
|
|
|
+ CHECK_DEVICE(rotary_sin);
|
|
|
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
|
|
|
+ CHECK_CONTIGUOUS(rotary_sin);
|
|
|
+ TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
|
|
|
+ params.rotary_cos_ptr = rotary_cos.data_ptr();
|
|
|
+ params.rotary_sin_ptr = rotary_sin.data_ptr();
|
|
|
+ params.is_rotary_interleaved = is_rotary_interleaved;
|
|
|
+ } else {
|
|
|
+ params.rotary_dim = 0;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
// This needs to match with run_mha_fwd_splitkv_dispatch
|
|
|
const int block_n = is_sm90 || is_sm8x
|
|
|
? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
|