# test_pa_lora.py
  1. from huggingface_hub import snapshot_download
  2. from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
  3. from aphrodite.lora.request import LoRARequest
  4. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  5. MODEL_PATH = "meta-llama/Llama-2-7b-hf"
  6. pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
  7. lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
  8. def do_sample(engine):
  9. prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
  10. # first prompt with a prompt adapter and second without adapter
  11. prompts = [
  12. (prompt_text,
  13. SamplingParams(temperature=0.0, max_tokens=100,
  14. stop=["[/assistant]"]),
  15. PromptAdapterRequest("hate_speech", 1, pa_path,
  16. 8), LoRARequest("sql_test", 1, lora_path)),
  17. (prompt_text,
  18. SamplingParams(temperature=0.0, max_tokens=100,
  19. stop=["[/assistant]"]), None,
  20. LoRARequest("sql_test", 1, lora_path)),
  21. ]
  22. request_id = 0
  23. results = set()
  24. while prompts or engine.has_unfinished_requests():
  25. if prompts:
  26. prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
  27. engine.add_request(str(request_id),
  28. prompt,
  29. sampling_params,
  30. prompt_adapter_request=pa_request,
  31. lora_request=lora_request)
  32. request_id += 1
  33. request_outputs = engine.step()
  34. for request_output in request_outputs:
  35. if request_output.finished:
  36. results.add(request_output.outputs[0].text)
  37. return results
  38. def test_lora_prompt_adapter():
  39. engine_args = EngineArgs(model=MODEL_PATH,
  40. enable_prompt_adapter=True,
  41. enable_lora=True,
  42. max_num_seqs=60,
  43. max_prompt_adapter_token=8)
  44. engine = AphroditeEngine.from_engine_args(engine_args)
  45. result = do_sample(engine)
  46. expected_output = {
  47. " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
  48. }
  49. assert result == expected_output