@@ -383,7 +383,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1350,7 +1350,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
     const int num_heads_k = kcache.size(2);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");