test_gemma.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from typing import List
  2. import aphrodite
  3. from aphrodite.lora.request import LoRARequest
# Hugging Face id of the base model; LoRA adapters are applied on top of it.
MODEL_PATH = "google/gemma-7b"
  5. def do_sample(llm: aphrodite.LLM, lora_path: str, lora_id: int) -> List[str]:
  6. prompts = [
  7. "Quote: Imagination is",
  8. "Quote: Be yourself;",
  9. "Quote: So many books,",
  10. ]
  11. sampling_params = aphrodite.SamplingParams(temperature=0, max_tokens=32)
  12. outputs = llm.generate(
  13. prompts,
  14. sampling_params,
  15. lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
  16. if lora_id else None)
  17. # Print the outputs.
  18. generated_texts: List[str] = []
  19. for output in outputs:
  20. prompt = output.prompt
  21. generated_text = output.outputs[0].text.strip()
  22. generated_texts.append(generated_text)
  23. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
  24. return generated_texts
  25. def test_gemma_lora(gemma_lora_files):
  26. llm = aphrodite.LLM(MODEL_PATH,
  27. max_model_len=1024,
  28. enable_lora=True,
  29. max_loras=4)
  30. expected_lora_output = [
  31. "more important than knowledge.\nAuthor: Albert Einstein\n",
  32. "everyone else is already taken.\nAuthor: Oscar Wilde\n",
  33. "so little time.\nAuthor: Frank Zappa\n",
  34. ]
  35. output1 = do_sample(llm, gemma_lora_files, lora_id=1)
  36. for i in range(len(expected_lora_output)):
  37. assert output1[i].startswith(expected_lora_output[i])
  38. output2 = do_sample(llm, gemma_lora_files, lora_id=2)
  39. for i in range(len(expected_lora_output)):
  40. assert output2[i].startswith(expected_lora_output[i])