123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- """
- This example shows how to use the multi-LoRA functionality for offline inference.
- Requires HuggingFace credentials for access to Llama2.
- """
from collections import deque
from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from aphrodite import (AphroditeEngine, EngineArgs, RequestOutput,
                       SamplingParams)
from aphrodite.lora.request import LoRARequest
def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create the demo prompts, each paired with sampling params and a LoRA.

    There are 2 requests for the base model (LoRA request ``None``) and 4
    requests that use a LoRA adapter. Two distinct adapters are defined,
    ``"sql-lora"`` (ID 1) and ``"sql-lora2"`` (ID 2), both loaded from the
    same ``lora_path`` for demo purposes. Since we also set ``max_loras=1``,
    only one adapter can be used in a batch, so the ``"sql-lora2"`` request
    is expected to wait until no ``"sql-lora"`` request is running.
    NOTE(review): the final request switches back to ``"sql-lora"``, so the
    exact interleaving depends on the scheduler — confirm before relying on
    a specific completion order.

    Args:
        lora_path: Local directory holding the adapter weights (``main``
            passes the result of ``snapshot_download``).

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples in the
        order they should be submitted.
    """
    # Each SQL prompt is sent twice, with different adapters / sampling
    # settings. Implicit string concatenation keeps the text byte-identical.
    icao_prompt = (
        "[user] Write a SQL query to answer the question based on the table "
        "schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, "
        "airport VARCHAR)\n\n question: Name the ICAO for lilongwe "
        "international airport [/user] [assistant]")
    nationality_prompt = (
        "[user] Write a SQL query to answer the question based on the table "
        "schema.\n\n context: CREATE TABLE table_name_11 (nationality "
        "VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was "
        "the elector what is under nationality? [/user] [assistant]")

    # stop_token_ids=[32003]: presumably the adapter's end-of-turn token id
    # — TODO confirm. A fresh list is built per request on purpose.
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (icao_prompt,
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
        (nationality_prompt,
         SamplingParams(n=3,
                        best_of=3,
                        temperature=0.8,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
        (icao_prompt,
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora2", 2, lora_path)),
        (nationality_prompt,
         SamplingParams(n=3,
                        best_of=3,
                        temperature=0.9,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
    ]
def process_requests(engine: AphroditeEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]
                     ) -> None:
    """Feed prompts to the engine one per step and print finished outputs.

    Before every ``engine.step()`` call, at most one new prompt is submitted
    until all prompts are queued. Stepping then continues until the engine
    reports no unfinished requests. Request IDs are the stringified indices
    (``"0"``, ``"1"``, ...) of the prompts in ``test_prompts``.

    Args:
        engine: Engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples;
            ``lora_request`` may be ``None``. The list is not modified.
    """
    # A local deque gives O(1) popleft and leaves the caller's list intact
    # (the previous ``list.pop(0)`` emptied it as a side effect).
    pending = deque(enumerate(test_prompts))
    while pending or engine.has_unfinished_requests():
        if pending:
            request_id, (prompt, sampling_params,
                         lora_request) = pending.popleft()
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
def initialize_engine() -> AphroditeEngine:
    """Build an AphroditeEngine for NousResearch/Llama-2-7b-hf with LoRA on.

    LoRA-related knobs passed to ``EngineArgs``:
      * ``max_loras``: how many LoRAs may be used in the same batch. Each
        LoRA slot needs its own preallocated tensor, so larger values cost
        more memory.
      * ``max_lora_rank``: the highest LoRA rank supported. Larger values
        cost more memory; if every LoRA shares one rank, keep this as low as
        possible.
      * ``max_cpu_loras``: capacity of the CPU LoRA cache.
    """
    lora_options = dict(enable_lora=True,
                        max_loras=1,
                        max_lora_rank=8,
                        max_cpu_loras=2)
    args = EngineArgs(model="NousResearch/Llama-2-7b-hf",
                      max_num_seqs=256,
                      **lora_options)
    return AphroditeEngine.from_engine_args(args)
def main():
    """Run the multi-LoRA demo end to end.

    Builds the engine, downloads the LoRA adapter with ``snapshot_download``,
    then submits the demo prompts and prints each finished request.
    """
    aphrodite_engine = initialize_engine()
    adapter_dir = snapshot_download(repo_id="alpindale/l2-lora-test")
    process_requests(aphrodite_engine, create_test_prompts(adapter_dir))


if __name__ == '__main__':
    main()
|