Bladeren bron

[Doc] Change total -> total_q

Tri Dao 1 jaar geleden
bovenliggende
commit
e8a0b4acdd
1 gewijzigde bestanden met toevoegingen van 2 en 2 verwijderingen
  1. 2 2
      flash_attn/flash_attn_interface.py

+ 2 - 2
flash_attn/flash_attn_interface.py

@@ -279,7 +279,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
            (they might not have the right scaling).
         deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
-        out: (total, nheads, headdim).
+        out: (total_q, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
@@ -315,7 +315,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            (they might not have the right scaling).
         deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
-        out: (total, nheads, headdim).
+        out: (total_q, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).