# slora_inference.py
  1. """
  2. This example shows how to use the multi-LoRA functionality for offline
  3. inference. Requires HuggingFace credentials for access to Llama2.
  4. """
  5. from typing import Optional, List, Tuple
  6. from huggingface_hub import snapshot_download
  7. from aphrodite import EngineArgs, AphroditeEngine, SamplingParams, RequestOutput
  8. from aphrodite.lora.request import LoRARequest
  9. def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
  10. """Create a list of test prompts with their sampling parameters.
  11. 2 requests for base model, 4 requests for the LoRA. We define 2
  12. different LoRA adapters (using the same model for demo purposes).
  13. Since we also set `max_loras=1`, the expectation is that the requests
  14. with the second LoRA adapter will be ran after all requests with the
  15. first adapter have finished.
  16. """
  17. return [
  18. (
  19. "A robot may not injure a human being",
  20. SamplingParams(
  21. temperature=0.0,
  22. # logprobs=1,
  23. prompt_logprobs=1,
  24. max_tokens=128),
  25. None),
  26. ("To be or not to be,",
  27. SamplingParams(temperature=0.8,
  28. top_k=5,
  29. presence_penalty=0.2,
  30. max_tokens=128), None),
  31. (
  32. """[user] Write a SQL query to answer the question based on the
  33. table schema.\n\n context: CREATE TABLE table_name_74
  34. (icao VARCHAR, airport VARCHAR)\n\n
  35. question: Name the ICAO for lilongwe
  36. international airport [/user] [assistant]""",
  37. SamplingParams(
  38. temperature=0.0,
  39. # logprobs=1,
  40. prompt_logprobs=1,
  41. max_tokens=128,
  42. stop_token_ids=[32003]),
  43. LoRARequest("l2-lora-test", 1, lora_path)),
  44. ("""[user] Write a SQL query to answer the question based on the table
  45. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  46. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  47. what is under nationality? [/user] [assistant]""",
  48. SamplingParams(n=3,
  49. best_of=3,
  50. temperature=0.8,
  51. max_tokens=128,
  52. stop_token_ids=[32003]),
  53. LoRARequest("l2-lora-test", 1, lora_path)),
  54. (
  55. """[user] Write a SQL query to answer the question based on the
  56. table schema.\n\n context: CREATE TABLE table_name_74 (icao
  57. VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe
  58. international airport [/user] [assistant]""",
  59. SamplingParams(
  60. temperature=0.0,
  61. # logprobs=1,
  62. prompt_logprobs=1,
  63. max_tokens=128,
  64. stop_token_ids=[32003]),
  65. LoRARequest("l2-lora-test2", 2, lora_path)),
  66. ("""[user] Write a SQL query to answer the question based on the table
  67. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  68. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  69. what is under nationality? [/user] [assistant]""",
  70. SamplingParams(n=3,
  71. best_of=3,
  72. temperature=0.9,
  73. max_tokens=128,
  74. stop_token_ids=[32003]),
  75. LoRARequest("l2-lora-test", 1, lora_path)),
  76. ] # type: ignore
  77. def process_requests(engine: AphroditeEngine,
  78. test_prompts: List[Tuple[str, SamplingParams,
  79. Optional[LoRARequest]]]):
  80. """Continuously process a list of prompts and handle the outputs."""
  81. request_id = 0
  82. while test_prompts or engine.has_unfinished_requests():
  83. if test_prompts:
  84. prompt, sampling_params, lora_request = test_prompts.pop(0)
  85. engine.add_request(str(request_id),
  86. prompt,
  87. sampling_params,
  88. lora_request=lora_request)
  89. request_id += 1
  90. request_outputs: List[RequestOutput] = engine.step()
  91. for request_output in request_outputs:
  92. if request_output.finished:
  93. print(request_output)
  94. def initialize_engine() -> AphroditeEngine:
  95. """Initialize the AphroditeEngine."""
  96. # max_loras: controls the number of LoRAs that can be used in the same
  97. # batch. Larger numbers will cause higher memory usage, as each LoRA
  98. # slot requires its own preallocated tensor.
  99. # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
  100. # numbers will cause higher memory usage. If you know that all LoRAs will
  101. # use the same rank, it is recommended to set this as low as possible.
  102. # max_cpu_loras: controls the size of the CPU LoRA cache.
  103. engine_args = EngineArgs(model="NousResearch/Llama-2-7b-hf",
  104. enable_lora=True,
  105. max_loras=1,
  106. max_lora_rank=8,
  107. max_cpu_loras=2,
  108. max_num_seqs=256)
  109. return AphroditeEngine.from_engine_args(engine_args)
  110. def main():
  111. """Main function that sets up and runs the prompt processing."""
  112. engine = initialize_engine()
  113. lora_path = snapshot_download(repo_id="alpindale/l2-lora-test")
  114. test_prompts = create_test_prompts(lora_path)
  115. process_requests(engine, test_prompts)
  116. if __name__ == '__main__':
  117. main()