
fix: issues with flashinfer fp8 kv (#950)

AlpinDale · 2 months ago
parent commit cef6da8863
2 changed files with 115 additions and 9 deletions
  1. aphrodite/attention/backends/flashinfer.py (+19 −9)
  2. tests/models/test_fp8kv_flashinfer.py (+96 −0)

+ 19 - 9
aphrodite/attention/backends/flashinfer.py

@@ -187,8 +187,12 @@ class FlashInferState(AttentionState):
             self._graph_decode_workspace_buffer, _indptr_buffer,
             self._graph_indices_buffer, _last_page_len_buffer, "NHD",
             use_tensor_cores)
-        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-            self.runner.kv_cache_dtype)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -348,7 +352,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use Aphrodite's rope.
                 pos_encoding_mode="NONE",
-            )
+                data_type=self.data_type)
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -584,8 +588,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None
 
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -685,10 +694,11 @@ class FlashInferImpl(AttentionImpl):
                 v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache in fp8
-            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                self.kv_cache_dtype)
-            kv_cache = kv_cache.view(torch_dtype)
+            # to process the cache when the kv_cache_dtype is fp8
+            if self.kv_cache_dtype.startswith("fp8"):
+                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                    self.kv_cache_dtype)
+                kv_cache = kv_cache.view(torch_dtype)
 
         query = query.contiguous(
         )  # Flashinfer requires query to be contiguous
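
The patch repeats the same dispatch in the CUDA-graph decode state and the metadata builder, and the attention impl now only reinterprets (view()s) the cache as an fp8 torch dtype when the cache dtype actually starts with "fp8": fp8 settings go through FlashInferBackend.get_fp8_dtype_for_flashinfer, everything else keeps the torch dtype derived from the model config. Below is a minimal standalone sketch of that dispatch; the helper name and the fp8 mapping (defaulting "fp8" to e4m3) are illustrative assumptions, not Aphrodite's API.

    import torch

    # Assumed mapping, mirroring what get_fp8_dtype_for_flashinfer is expected
    # to return; "fp8" is treated as e4m3 by default here.
    _FP8_MAP = {
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
    }

    def select_flashinfer_kv_cache_dtype(kv_cache_dtype: str,
                                         model_dtype: torch.dtype) -> torch.dtype:
        """Pick the torch dtype the KV cache is viewed as (sketch only)."""
        if kv_cache_dtype.startswith("fp8"):
            return _FP8_MAP[kv_cache_dtype]
        # Non-fp8 settings ("auto" in the common case) keep the dtype derived
        # from the model, which is what get_kv_cache_torch_dtype provides in
        # the patch; other string settings are not handled in this sketch.
        return model_dtype

    # An fp8 cache with a bf16 model is viewed as float8_e4m3fn; "auto" keeps bf16.
    assert select_flashinfer_kv_cache_dtype("fp8", torch.bfloat16) == torch.float8_e4m3fn
    assert select_flashinfer_kv_cache_dtype("auto", torch.bfloat16) == torch.bfloat16

In the patch itself the branch is inlined at each call site rather than factored into a helper like this.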

+ 96 - 0
tests/models/test_fp8kv_flashinfer.py

@@ -0,0 +1,96 @@
+# flake8: noqa
+"""Tests fp8 models against ground truth generation
+This verifies the flashinfer backend with fp8 
+quantization and fp8 KV Cache without scaling 
+factors Note: these tests will only pass on H100 GPU.
+"""
+import os
+from typing import List
+
+import pytest
+from transformers import AutoTokenizer
+
+from aphrodite import LLM, SamplingParams
+from tests.quantization.utils import is_quant_method_supported
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+MAX_MODEL_LEN = 1024
+
+MODELS = [
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
+]
+
+EXPECTED_STRS_MAP = {
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": {
+        "auto": [
+            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+            'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o',
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
+        ]
+    }
+}
+
+
+# This test compares against golden strings for an exact match, since
+# there is no baseline implementation to compare against. It is
+# unstable w.r.t. the specifics of the fp8 implementation and the
+# hardware it is run on.
+# There is no assert, to prevent it from breaking the build.
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model_name", MODELS)
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"])
+def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None:
+    # Note that the golden strings may not match for the FLASHINFER backend;
+    # the intention is only to exercise the code path.
+    os.environ["APHRODITE_ATTENTION_BACKEND"] = backend
+    model = LLM(model=model_name,
+                max_model_len=MAX_MODEL_LEN,
+                trust_remote_code=True,
+                quantization="fp8",
+                kv_cache_dtype=kv_cache_dtype)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    formatted_prompts = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      tokenize=False,
+                                      add_generation_prompt=True)
+        for prompt in example_prompts
+    ]
+
+    params = SamplingParams(max_tokens=20, temperature=0)
+    generations: List[str] = []
+    # Note: these need to be run 1 at a time due to numerical precision,
+    # since the expected strs were generated this way.
+    for prompt in formatted_prompts:
+        outputs = model.generate(prompt, params)
+        generations.append(outputs[0].outputs[0].text)
+    del model
+
+    print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}")
+    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
+    for i in range(len(example_prompts)):
+        generated_str = generations[i]
+        expected_str = expected_strs[i]
+        print(f"generated_str\n: {generated_str}")
+        print(f"expected_str\n: {expected_str}")