test_bitsandbytes.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. '''Tests whether bitsandbytes computation is enabled correctly.
  2. Run `pytest tests/quantization/test_bitsandbytes.py`.
  3. '''
  4. import gc
  5. import pytest
  6. import torch
  7. from tests.quantization.utils import is_quant_method_supported
  8. models_4bit_to_test = [
  9. ('huggyllama/llama-7b', 'quantize model inflight'),
  10. ]
  11. models_pre_qaunt_4bit_to_test = [
  12. ('lllyasviel/omost-llama-3-8b-4bits',
  13. 'read pre-quantized 4-bit NF4 model'),
  14. ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
  15. 'read pre-quantized 4-bit FP4 model'),
  16. ]
  17. models_pre_quant_8bit_to_test = [
  18. ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
  19. ]
  20. @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
  21. reason='bitsandbytes is not supported on this GPU type.')
  22. @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
  23. def test_load_4bit_bnb_model(hf_runner, aphrodite_runner, example_prompts,
  24. model_name, description) -> None:
  25. hf_model_kwargs = {"load_in_4bit": True}
  26. validate_generated_texts(hf_runner, aphrodite_runner, example_prompts[:1],
  27. model_name, hf_model_kwargs)
  28. @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
  29. reason='bitsandbytes is not supported on this GPU type.')
  30. @pytest.mark.parametrize("model_name, description",
  31. models_pre_qaunt_4bit_to_test)
  32. def test_load_pre_quant_4bit_bnb_model(hf_runner, aphrodite_runner,
  33. example_prompts,
  34. model_name, description) -> None:
  35. validate_generated_texts(hf_runner, aphrodite_runner, example_prompts[:1],
  36. model_name)
  37. @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
  38. reason='bitsandbytes is not supported on this GPU type.')
  39. @pytest.mark.parametrize("model_name, description",
  40. models_pre_quant_8bit_to_test)
  41. def test_load_8bit_bnb_model(hf_runner, aphrodite_runner, example_prompts,
  42. model_name, description) -> None:
  43. validate_generated_texts(hf_runner, aphrodite_runner, example_prompts[:1],
  44. model_name)
  45. def log_generated_texts(prompts, outputs, runner_name):
  46. logged_texts = []
  47. for i, (_, generated_text) in enumerate(outputs):
  48. log_entry = {
  49. "prompt": prompts[i],
  50. "runner_name": runner_name,
  51. "generated_text": generated_text,
  52. }
  53. logged_texts.append(log_entry)
  54. return logged_texts
  55. def validate_generated_texts(hf_runner,
  56. aphrodite_runner,
  57. prompts,
  58. model_name,
  59. hf_model_kwargs=None):
  60. if hf_model_kwargs is None:
  61. hf_model_kwargs = {}
  62. # Run with HF runner
  63. with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
  64. hf_outputs = llm.generate_greedy(prompts, 8)
  65. hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
  66. # Clean up the GPU memory for the next test
  67. torch.cuda.synchronize()
  68. gc.collect()
  69. torch.cuda.empty_cache()
  70. #Run with Aphrodite runner
  71. with aphrodite_runner(model_name,
  72. quantization='bitsandbytes',
  73. load_format='bitsandbytes',
  74. enforce_eager=True,
  75. gpu_memory_utilization=0.8) as llm:
  76. aphrodite_outputs = llm.generate_greedy(prompts, 8)
  77. aphrodite_logs = log_generated_texts(prompts, aphrodite_outputs,
  78. "AphroditeRunner")
  79. # Clean up the GPU memory for the next test
  80. torch.cuda.synchronize()
  81. gc.collect()
  82. torch.cuda.empty_cache()
  83. # Compare the generated strings
  84. for hf_log, aphrodite_log in zip(hf_logs, aphrodite_logs):
  85. hf_str = hf_log["generated_text"]
  86. aphrodite_str = aphrodite_log["generated_text"]
  87. prompt = hf_log["prompt"]
  88. assert hf_str == aphrodite_str, (f"Model: {model_name}"
  89. f"Mismatch between HF and Aphrodite "
  90. "outputs:\n"
  91. f"Prompt: {prompt}\n"
  92. f"HF Output: '{hf_str}'\n"
  93. f"Aphrodite Output: '{aphrodite_str}'")