1
0

slora_inference.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
"""
This example shows how to use the multi-LoRA functionality for offline
inference. Requires Hugging Face credentials for access to Llama 2.
"""
from collections import deque
from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from aphrodite import (AphroditeEngine, EngineArgs, RequestOutput,
                       SamplingParams)
from aphrodite.lora.request import LoRARequest
  10. def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
  11. """Create a list of test prompts with their sampling parameters.
  12. 2 requests for base model, 4 requests for the LoRA. We define 2
  13. different LoRA adapters (using the same model for demo purposes).
  14. Since we also set `max_loras=1`, the expectation is that the requests
  15. with the second LoRA adapter will be ran after all requests with the
  16. first adapter have finished.
  17. """
  18. return [
  19. (
  20. "A robot may not injure a human being",
  21. SamplingParams(
  22. temperature=0.0,
  23. # logprobs=1,
  24. prompt_logprobs=1,
  25. max_tokens=128),
  26. None),
  27. ("To be or not to be,",
  28. SamplingParams(temperature=0.8,
  29. top_k=5,
  30. presence_penalty=0.2,
  31. max_tokens=128), None),
  32. (
  33. """[user] Write a SQL query to answer the question based on the
  34. table schema.\n\n context: CREATE TABLE table_name_74
  35. (icao VARCHAR, airport VARCHAR)\n\n
  36. question: Name the ICAO for lilongwe
  37. international airport [/user] [assistant]""",
  38. SamplingParams(
  39. temperature=0.0,
  40. # logprobs=1,
  41. prompt_logprobs=1,
  42. max_tokens=128,
  43. stop_token_ids=[32003]),
  44. LoRARequest("l2-lora-test", 1, lora_path)),
  45. ("""[user] Write a SQL query to answer the question based on the table
  46. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  47. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  48. what is under nationality? [/user] [assistant]""",
  49. SamplingParams(n=3,
  50. best_of=3,
  51. temperature=0.8,
  52. max_tokens=128,
  53. stop_token_ids=[32003]),
  54. LoRARequest("l2-lora-test", 1, lora_path)),
  55. (
  56. """[user] Write a SQL query to answer the question based on the
  57. table schema.\n\n context: CREATE TABLE table_name_74 (icao
  58. VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe
  59. international airport [/user] [assistant]""",
  60. SamplingParams(
  61. temperature=0.0,
  62. # logprobs=1,
  63. prompt_logprobs=1,
  64. max_tokens=128,
  65. stop_token_ids=[32003]),
  66. LoRARequest("l2-lora-test2", 2, lora_path)),
  67. ("""[user] Write a SQL query to answer the question based on the table
  68. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  69. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  70. what is under nationality? [/user] [assistant]""",
  71. SamplingParams(n=3,
  72. best_of=3,
  73. temperature=0.9,
  74. max_tokens=128,
  75. stop_token_ids=[32003]),
  76. LoRARequest("l2-lora-test", 1, lora_path)),
  77. ] # type: ignore
  78. def process_requests(engine: AphroditeEngine,
  79. test_prompts: List[Tuple[str, SamplingParams,
  80. Optional[LoRARequest]]]):
  81. """Continuously process a list of prompts and handle the outputs."""
  82. request_id = 0
  83. while test_prompts or engine.has_unfinished_requests():
  84. if test_prompts:
  85. prompt, sampling_params, lora_request = test_prompts.pop(0)
  86. engine.add_request(str(request_id),
  87. prompt,
  88. sampling_params,
  89. lora_request=lora_request)
  90. request_id += 1
  91. request_outputs: List[RequestOutput] = engine.step()
  92. for request_output in request_outputs:
  93. if request_output.finished:
  94. print(request_output)
  95. def initialize_engine() -> AphroditeEngine:
  96. """Initialize the AphroditeEngine."""
  97. # max_loras: controls the number of LoRAs that can be used in the same
  98. # batch. Larger numbers will cause higher memory usage, as each LoRA
  99. # slot requires its own preallocated tensor.
  100. # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
  101. # numbers will cause higher memory usage. If you know that all LoRAs will
  102. # use the same rank, it is recommended to set this as low as possible.
  103. # max_cpu_loras: controls the size of the CPU LoRA cache.
  104. engine_args = EngineArgs(model="NousResearch/Llama-2-7b-hf",
  105. enable_lora=True,
  106. max_loras=1,
  107. max_lora_rank=8,
  108. max_cpu_loras=2,
  109. max_num_seqs=256)
  110. return AphroditeEngine.from_engine_args(engine_args)
  111. def main():
  112. """Main function that sets up and runs the prompt processing."""
  113. engine = initialize_engine()
  114. lora_path = snapshot_download(repo_id="alpindale/l2-lora-test")
  115. test_prompts = create_test_prompts(lora_path)
  116. process_requests(engine, test_prompts)
  117. if __name__ == '__main__':
  118. main()