12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- from huggingface_hub import snapshot_download
- from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
- from aphrodite.lora.request import LoRARequest
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
# Model identifier handed to EngineArgs(model=...) by the test in this file.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"

# Adapter weights are fetched from the Hugging Face Hub when this module is
# imported; each variable holds whatever snapshot_download returns for the
# repo and is later passed as the adapter path in the request objects.
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
    """Feed two SQL-generation requests to ``engine`` and gather the results.

    Both requests use the same prompt text and the same LoRA request; only
    the first one also carries a prompt adapter request. At most one request
    is submitted per loop iteration, immediately before an ``engine.step()``
    call, and the loop keeps stepping until every request has been submitted
    and the engine reports no unfinished requests.

    Args:
        engine: Object providing ``add_request``, ``step`` and
            ``has_unfinished_requests``.

    Returns:
        A set containing ``outputs[0].text`` of every request output that was
        reported as finished.
    """
    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # Each entry: (prompt, sampling params, prompt adapter request or None,
    # LoRA request). The first entry uses a prompt adapter, the second does
    # not.
    pending = [
        (
            prompt_text,
            SamplingParams(temperature=0.0,
                           max_tokens=100,
                           stop=["[/assistant]"]),
            PromptAdapterRequest("hate_speech", 1, pa_path, 8),
            LoRARequest("sql_test", 1, lora_path),
        ),
        (
            prompt_text,
            SamplingParams(temperature=0.0,
                           max_tokens=100,
                           stop=["[/assistant]"]),
            None,
            LoRARequest("sql_test", 1, lora_path),
        ),
    ]

    total = len(pending)
    # Index of the next entry to submit; it is also used as that request's id.
    submitted = 0
    finished_texts = set()

    # The engine is only asked about unfinished work once everything has been
    # submitted (the left operand short-circuits until then).
    while submitted < total or engine.has_unfinished_requests():
        if submitted < total:
            text, params, adapter_req, lora_req = pending[submitted]
            engine.add_request(str(submitted),
                               text,
                               params,
                               prompt_adapter_request=adapter_req,
                               lora_request=lora_req)
            submitted += 1

        finished_texts.update(out.outputs[0].text
                              for out in engine.step()
                              if out.finished)

    return finished_texts
def test_lora_prompt_adapter():
    """Check generation with LoRA and prompt adapters enabled together.

    Builds an engine with both adapter kinds switched on, runs ``do_sample``
    against it, and asserts that the returned set of generated texts is
    exactly the single expected SQL string.
    """
    engine = AphroditeEngine.from_engine_args(
        EngineArgs(
            model=MODEL_PATH,
            enable_prompt_adapter=True,
            enable_lora=True,
            max_num_seqs=60,
            max_prompt_adapter_token=8,
        ))

    generated = do_sample(engine)

    # The result is a set, so it must collapse to this one text.
    expected_sql = " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    assert generated == {expected_sql}
|