# Adapted from
# https://github.com/fmmoret/aphrodite/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import List

import pytest

import aphrodite
from aphrodite.lora.request import LoRARequest

from .conftest import cleanup


@dataclass
class ModelWithQuantization:
    model_path: str
    quantization: str
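
# Quantized TinyLlama checkpoints exercised by this module. The expected
# generations in test_quant_model_lora depend on which scheme is loaded.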
MODELS: List[ModelWithQuantization] = [
    ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
                          quantization="AWQ"),
    ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
                          quantization="GPTQ"),
]


def do_sample(llm: aphrodite.LLM,
              lora_path: str,
              lora_id: int,
              max_tokens: int = 256) -> List[str]:
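    """Greedily decode two fixed color prompts, optionally with the LoRA
    adapter at ``lora_path`` attached, and return the generated texts."""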
    raw_prompts = [
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    ]
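
    # Wrap each raw prompt in the ChatML-style template the chat model expects.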
    def format_prompt_tuples(prompt):
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = [format_prompt_tuples(p) for p in raw_prompts]

    sampling_params = aphrodite.SamplingParams(temperature=0,
                                               max_tokens=max_tokens,
                                               stop=["<|im_end|>"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
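    """Check that a quantized model produces the expected generations with and
    without a LoRA adapter, and that the adapter can be toggled repeatedly."""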
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < tp_size:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    llm = aphrodite.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        max_model_len=400,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        trust_remote_code=True)
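
    # Reference outputs differ per quantization scheme. MODELS currently only
    # lists AWQ and GPTQ checkpoints, so the `None` branch is not exercised.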
    if model.quantization is None:
        expected_no_lora_output = [
            "Here are some examples of orange-brown colors",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#ff8050",
            "#ff8080",
        ]
    elif model.quantization == "AWQ":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
        ]
        expected_lora_output = [
            "#f07700: A v",
            "#f00000: A v",
        ]
    elif model.quantization == "GPTQ":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#f08800: This is",
            "#f07788 \n#",
        ]
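
    # Compare generations against the reference strings above, with a relaxed
    # check for the unstable GPTQ LoRA case.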
    def expect_match(output, expected_output):
        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        if (model.quantization == "GPTQ"
                and expected_output is expected_lora_output):
            assert output != expected_no_lora_output
            for i, o in enumerate(output):
                assert o.startswith(
                    '#'), f"Expected example {i} to start with # but got {o}"
            return
        assert output == expected_output
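
    # Alternate between lora_id=0 (no adapter) and nonzero lora_ids to verify
    # the adapter can be attached, detached, and reattached on the same engine.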
    max_tokens = 10

    print("lora adapter created")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=0,
                       max_tokens=max_tokens)
    expect_match(output, expected_no_lora_output)

    print("lora 1")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=1,
                       max_tokens=max_tokens)
    expect_match(output, expected_lora_output)

    print("no lora")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=0,
                       max_tokens=max_tokens)
    expect_match(output, expected_no_lora_output)

    print("lora 2")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=2,
                       max_tokens=max_tokens)
    expect_match(output, expected_lora_output)

    print("removing lora")
    del llm
    cleanup()


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
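    """Check that LoRA generations are identical for tensor_parallel_size=1
    and tensor_parallel_size=2 on a quantized model."""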
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 2:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")

    llm_tp1 = aphrodite.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        trust_remote_code=True)
    output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

    del llm_tp1
    cleanup()

    llm_tp2 = aphrodite.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization)
    output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

    del llm_tp2
    cleanup()

    assert output_tp1 == output_tp2