slora_inference.py

  1. """
  2. This example shows how to use the multi-LoRA functionality for offline inference.
  3. Requires HuggingFace credentials for access to Llama2.
  4. """

from typing import Optional, List, Tuple

from huggingface_hub import snapshot_download

from aphrodite import EngineArgs, AphroditeEngine, SamplingParams, RequestOutput
from aphrodite.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for the base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be run after all requests with the
    first adapter have finished.
    """
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        # logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
         SamplingParams(temperature=0.0,
                        # logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("l2-lora-test", 1, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
         SamplingParams(n=3,
                        best_of=3,
                        temperature=0.8,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("l2-lora-test", 1, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
         SamplingParams(temperature=0.0,
                        # logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("l2-lora-test2", 2, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
         SamplingParams(n=3,
                        best_of=3,
                        temperature=0.9,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("l2-lora-test", 1, lora_path)),
    ]
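

# To try your own adapter (illustrative; assumes a locally saved LoRA adapter
# directory), append a tuple of the same shape. The adapter name, integer ID,
# and path below are hypothetical; the ID only needs to be unique per adapter:
#
#   ("your prompt here", SamplingParams(max_tokens=64),
#    LoRARequest("my-adapter", 3, "/path/to/my-adapter"))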


def process_requests(engine: AphroditeEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        # Feed one new request per iteration until the prompt queue drains.
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        # Each step() performs one engine iteration and returns the newly
        # generated results; print every request that has finished.
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
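

def print_finished(request_output: RequestOutput) -> None:
    """Illustrative alternative printer (a sketch, not part of the original
    example): print only the generated text of each completion, assuming the
    vLLM-style `RequestOutput` API (`.request_id`, `.outputs[i].text`) that
    Aphrodite exposes. Swap it in for `print(request_output)` above if the
    full repr is too noisy."""
    for completion in request_output.outputs:
        print(f"[{request_output.request_id}] {completion.text!r}")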


def initialize_engine() -> AphroditeEngine:
    """Initialize the AphroditeEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    #   batch. Larger numbers will cause higher memory usage, as each LoRA
    #   slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    #   numbers will cause higher memory usage. If you know that all LoRAs
    #   will use the same rank, it is recommended to set this as low as
    #   possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
    engine_args = EngineArgs(model="NousResearch/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return AphroditeEngine.from_engine_args(engine_args)
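

# For comparison (a sketch using only the parameters shown above): raising
# `max_loras` to 2 would let both test adapters run in the same batch, at the
# cost of a second preallocated LoRA slot:
#
#   EngineArgs(model="NousResearch/Llama-2-7b-hf", enable_lora=True,
#              max_loras=2, max_lora_rank=8, max_cpu_loras=2, max_num_seqs=256)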


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="alpindale/l2-lora-test")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()
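
# Usage (a sketch): run directly with `python slora_inference.py`. On first
# run, `snapshot_download` fetches the test adapter from the HuggingFace Hub,
# so network access is required.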