# tests/models/test_aqlm.py
  1. """Compare the outputs of a AQLM model between Aphrodite and HF Transformers
  2. Run `pytest tests/models/test_aqlm.py`.
  3. """
  4. import pytest
  5. from tests.quantization.utils import is_quant_method_supported
  6. # In this test we hardcode prompts and generations for the model so we don't
  7. # need to require the AQLM package as a dependency
  8. example_prompts = [
  9. 'Aphrodite is a high-throughput and memory-efficient inference and serving '
  10. 'engine for LLMs.\n',
  11. 'Briefly describe the major milestones in the development of artificial '
  12. 'intelligence from 1950 to 2020.\n',
  13. 'Compare and contrast artificial intelligence with human intelligence in '
  14. 'terms of processing information.\n',
  15. 'Describe the basic components of a neural network and how it can be '
  16. 'trained.\n',
  17. 'Write a short story about a robot that dreams for the first time.\n',
  18. 'Analyze the impact of the COVID-19 pandemic on global economic structures '
  19. 'and future business models.\n',
  20. 'Explain the cultural significance of the Mona Lisa painting, and how its '
  21. 'perception might vary in Western versus Eastern societies.\n',
  22. "Translate the following English sentence into Japanese, French, and "
  23. "Swahili: 'The early bird catches the worm.'\n"
  24. ]
  25. # These ground truth generations were generated using `transformers==4.38.1
  26. # aqlm==1.1.0 torch==2.2.0`
  27. # and the below code:
  28. # ```python
  29. # from transformers import AutoTokenizer, AutoModelForCausalLM
  30. # model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
  31. # quantized_model = AutoModelForCausalLM.from_pretrained(model_id,
  32. # torch_dtype="auto", device_map="cuda").cuda()
  33. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  34. # outputs = []
  35. # for prompt in example_prompts:
  36. # input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
  37. # hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32)
  38. # outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:]))
  39. # print(outputs)
  40. # ```
  41. ground_truth_generations = [
  42. '\n### Features\n\n- **High-throughput**: v',
  43. 'The major milestones in the development of artificial intelligence from '
  44. '195',
  45. 'Compare and contrast artificial intelligence with human intelligence in '
  46. 'terms of processing information. The',
  47. 'Explain the difference between supervised and unsupervised learning.'
  48. '\nExplain',
  49. 'Write a short story about a robot that dreams for the first time. The',
  50. 'Analyze the impact of the COVID-19 pandemic on global economic',
  51. 'The Mona Lisa is a painting by Leonardo da Vinci, and it',
  52. 'The early bird catches the worm.\nThe early bird catches the'
  53. ]
  54. @pytest.mark.skipif(not is_quant_method_supported("aqlm"),
  55. reason="AQLM is not supported on this GPU type.")
  56. @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
  57. @pytest.mark.parametrize("dtype", ["half"])
  58. @pytest.mark.parametrize("max_tokens", [16])
  59. @pytest.mark.parametrize("num_logprobs", [1])
  60. def test_models(
  61. aphrodite_runner,
  62. example_prompts,
  63. model: str,
  64. dtype: str,
  65. max_tokens: int,
  66. num_logprobs: int,
  67. ) -> None:
  68. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  69. aphrodite_outputs = aphrodite_model.generate_greedy_logprobs(
  70. example_prompts, max_tokens, num_logprobs)
  71. # loop through the prompts to compare against the ground truth generations
  72. for prompt_idx in range(len(example_prompts)):
  73. aphrodite_output_ids, aphrodite_output_str, aphrodite_logprobs = (
  74. aphrodite_outputs[prompt_idx])
  75. print("Prompt: ", repr(example_prompts[prompt_idx]))
  76. print("Reference output:", repr(ground_truth_generations[prompt_idx]))
  77. print("Output output: ", repr(aphrodite_output_str))
  78. assert aphrodite_output_str == ground_truth_generations[prompt_idx]