  1. """Compare the outputs of a AQLM model between Aphrodite and HF Transformers
  2. Run `pytest tests/models/test_aqlm.py`.
  3. """
  4. import pytest
  5. from tests.quantization.utils import is_quant_method_supported
  6. # These ground truth generations were generated using `transformers==4.38.1
  7. # aqlm==1.1.0 torch==2.2.0`
  8. # and the below code:
  9. # ```python
  10. # from transformers import AutoTokenizer, AutoModelForCausalLM
  11. # model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
  12. # quantized_model = AutoModelForCausalLM.from_pretrained(model_id,
  13. # torch_dtype="auto", device_map="cuda").cuda()
  14. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  15. # outputs = []
  16. # for prompt in example_prompts:
  17. # input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
  18. # hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32)
  19. # outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:]))
  20. # print(outputs)
  21. # ```
  22. ground_truth_generations = [
  23. '\n### Features\n\n- **High-throughput**: v',
  24. 'The major milestones in the development of artificial intelligence from '
  25. '195',
  26. 'Compare and contrast artificial intelligence with human intelligence in '
  27. 'terms of processing information. The',
  28. 'Explain the difference between supervised and unsupervised learning.'
  29. '\nExplain',
  30. 'Write a short story about a robot that dreams for the first time. The',
  31. 'Analyze the impact of the COVID-19 pandemic on global economic',
  32. 'The Mona Lisa is a painting by Leonardo da Vinci, and it',
  33. 'The early bird catches the worm.\nThe early bird catches the'
  34. ]
  35. @pytest.mark.skipif(not is_quant_method_supported("aqlm"),
  36. reason="AQLM is not supported on this GPU type.")
  37. @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
  38. @pytest.mark.parametrize("dtype", ["half"])
  39. @pytest.mark.parametrize("max_tokens", [16])
  40. @pytest.mark.parametrize("num_logprobs", [1])
  41. def test_models(
  42. aphrodite_runner,
  43. example_prompts,
  44. model: str,
  45. dtype: str,
  46. max_tokens: int,
  47. num_logprobs: int,
  48. ) -> None:
  49. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  50. aphrodite_outputs = aphrodite_model.generate_greedy_logprobs(
  51. example_prompts, max_tokens, num_logprobs)
  52. # loop through the prompts to compare against the ground truth generations
  53. for prompt_idx in range(len(example_prompts)):
  54. (aphrodite_output_ids, aphrodite_output_str,
  55. aphrodite_logprobs) = aphrodite_outputs[
  56. prompt_idx]
  57. print("Prompt: ", repr(example_prompts[prompt_idx]))
  58. print("Reference output:", repr(ground_truth_generations[prompt_idx]))
  59. print("Output output: ", repr(aphrodite_output_str))
  60. assert aphrodite_output_str == ground_truth_generations[prompt_idx]