Browse Source

fix: calculate the key/value outputs with kvhead

AlpinDale 1 year ago
parent
commit
89c7f0469f
3 changed files with 53 additions and 18 deletions
  1. +37 -17  aphrodite/modeling/layers/attention.py
  2. +1 -1    aphrodite/processing/block_manager.py
  3. +15 -0   test.py

+ 37 - 17
aphrodite/modeling/layers/attention.py

@@ -38,12 +38,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
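
The new head_mapping tensor tells the paged-attention CUDA kernel which KV head each query head should read from during decoding. A minimal sketch (not part of the diff, using assumed head counts) of how torch.repeat_interleave builds that mapping:

    import torch

    # Assumed example sizes, not taken from the diff.
    num_heads, num_kv_heads = 8, 2
    num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32),
        num_queries_per_kv)
    print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
    # Query heads 0-3 read KV head 0; query heads 4-7 read KV head 1.
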
@@ -70,11 +81,19 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-        #TODO: The unsqueeze op may incur some CPU overhead. See if optimzation is possible.
+
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
@@ -84,7 +103,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
             op=self.attn_op,
         )
-        # TODO: Unnecessary copy. Optimize.
+        # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
         return output
 
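For the prompt (prefill) path, xformers' memory_efficient_attention_forward expects query, key, and value to carry the same number of heads, so the KV tensors are expanded with repeat_interleave along the head dimension before the call. A minimal CPU-only sketch with assumed shapes:

    import torch

    # Assumed example shapes, not taken from the diff.
    num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 64
    num_queries_per_kv = num_heads // num_kv_heads

    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Repeat each KV head so every query head has a matching key/value slice.
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    print(key.shape, value.shape)  # torch.Size([4, 8, 64]) torch.Size([4, 8, 64])
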
@@ -101,9 +120,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -112,6 +131,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
@@ -137,11 +157,12 @@ class PagedAttention(nn.Module):
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
 
@@ -151,8 +172,8 @@ class PagedAttention(nn.Module):
 
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
@@ -308,8 +329,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             bias = bias.to(self.alibi_slopes.device)
 
             # When using custom attention bias, xformers requires the bias to
-            # be sliced from a tensor whose length is a multiple of 8. Likely 
-            # due to a CUDA Tensor code limitation.
+            # be sliced from a tensor whose length is a multiple of 8.
             padded_len = (prompt_len + 7) // 8 * 8
             bias = torch.empty(
                 self.num_heads,
@@ -338,7 +358,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             value: shape = [num_prompt_tokens, num_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-        # FIXME: Because xformers does not support dynamic sequence
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
         start = 0
@@ -353,7 +373,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 scale=self.scale,
                 op=self.attn_op,
             )
-            # TODO: Unnecessary copy. Optimize.
+            # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
             start += prompt_len
         return output
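
The two code paths are consistent: expanding the KV tensors with repeat_interleave (prompt path) selects the same heads that the kernel picks via head_mapping (decode path). A small self-contained check, with assumed sizes:

    import torch

    num_heads, num_kv_heads, head_size, num_tokens = 8, 2, 64, 5
    num_queries_per_kv = num_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(torch.arange(num_kv_heads),
                                           num_queries_per_kv)

    key = torch.randn(num_tokens, num_kv_heads, head_size)
    expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)  # prompt path
    gathered = key[:, head_mapping, :]                                  # decode-style gather
    print(torch.equal(expanded, gathered))  # True
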

+ 1 - 1
aphrodite/processing/block_manager.py

@@ -212,7 +212,7 @@ class BlockSpaceManager:
         if seq.seq_id not in self.block_tables:
             return
         block_table = self.block_tables[seq.seq_id]
-        self._free_block_table[block_table]
+        self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
 
     def reset(self) -> None:
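
The block_manager.py change is a one-character bug fix: the original line subscripted the bound method instead of calling it, which raises a TypeError instead of freeing the block table (assuming _free_block_table is an ordinary method). A minimal sketch of the failure mode, using a made-up class rather than aphrodite's real one:

    # Illustrative stand-in class; names are not aphrodite's real API.
    class Manager:
        def free_table(self, table):
            print(f"freed {table}")

    m = Manager()
    m.free_table(["block0", "block1"])      # correct: calls the method
    try:
        m.free_table[["block0", "block1"]]  # bug: subscripts the method object
    except TypeError as err:
        print(err)                          # 'method' object is not subscriptable
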

+ 15 - 0
test.py

@@ -0,0 +1,15 @@
+from aphrodite import LLM, SamplingParams
+
+prompts = [
+  "What is a man? A",
+  "The sun is a wondrous body, like a magnificent",
+  "All flesh is grass and all the comeliness thereof",
+]
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+llm = LLM(model="/home/alpindale/AI-Stuff/models/Pythia-70M")
+outputs = llm.generate(prompts, sampling_params)
+for output in outputs:
+  prompt = output.prompt
+  generated_text = output.outputs[0].text
+  print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")