lora_async_aphrodite.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """
  2. This example shows how to use the multi-LoRA functionality for offline
  3. inference. Requires HuggingFace credentials for access to Llama2.
  4. """
  5. import asyncio
  6. from typing import List, Optional, Tuple
  7. from aphrodite import AsyncAphrodite, AsyncEngineArgs, SamplingParams
  8. from aphrodite.lora.request import LoRARequest
  9. def create_test_prompts(
  10. lora_path: str
  11. ) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
  12. """Create a list of test prompts with their sampling parameters.
  13. 2 requests for base model, 4 requests for the LoRA. We define 2
  14. different LoRA adapters (using the same model for demo purposes).
  15. Since we also set `max_loras=1`, the expectation is that the requests
  16. with the second LoRA adapter will be ran after all requests with the
  17. first adapter have finished.
  18. """
  19. return [
  20. (
  21. "A robot may not injure a human being",
  22. SamplingParams(
  23. temperature=0.0,
  24. # logprobs=1,
  25. prompt_logprobs=1,
  26. max_tokens=128),
  27. None),
  28. ("To be or not to be,",
  29. SamplingParams(temperature=0.8,
  30. top_k=5,
  31. presence_penalty=0.2,
  32. max_tokens=128), None),
  33. (
  34. """[user] Write a SQL query to answer the question based on the
  35. table schema.\n\n context: CREATE TABLE table_name_74
  36. (icao VARCHAR, airport VARCHAR)\n\n
  37. question: Name the ICAO for lilongwe
  38. international airport [/user] [assistant]""",
  39. SamplingParams(
  40. temperature=0.0,
  41. # logprobs=1,
  42. prompt_logprobs=1,
  43. max_tokens=128,
  44. stop_token_ids=[32003]),
  45. LoRARequest(
  46. lora_name="l2-lora-test",
  47. lora_int_id=1,
  48. lora_path=lora_path
  49. )),
  50. ("""[user] Write a SQL query to answer the question based on the table
  51. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  52. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  53. what is under nationality? [/user] [assistant]""",
  54. SamplingParams(n=3,
  55. best_of=3,
  56. temperature=0.8,
  57. max_tokens=128,
  58. stop_token_ids=[32003]),
  59. LoRARequest(
  60. lora_name="l2-lora-test",
  61. lora_int_id=1,
  62. lora_path=lora_path
  63. )),
  64. (
  65. """[user] Write a SQL query to answer the question based on the
  66. table schema.\n\n context: CREATE TABLE table_name_74 (icao
  67. VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe
  68. international airport [/user] [assistant]""",
  69. SamplingParams(
  70. temperature=0.0,
  71. # logprobs=1,
  72. prompt_logprobs=1,
  73. max_tokens=128,
  74. stop_token_ids=[32003]),
  75. LoRARequest(
  76. lora_name="l2-lora-test2",
  77. lora_int_id=2,
  78. lora_path=lora_path
  79. )),
  80. ("""[user] Write a SQL query to answer the question based on the table
  81. schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
  82. elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
  83. what is under nationality? [/user] [assistant]""",
  84. SamplingParams(n=3,
  85. best_of=3,
  86. temperature=0.9,
  87. max_tokens=128,
  88. stop_token_ids=[32003]),
  89. LoRARequest(
  90. lora_name="l2-lora-test",
  91. lora_int_id=1,
  92. lora_path=lora_path
  93. )),
  94. ] # type: ignore
  95. async def process_requests(engine: AsyncAphrodite,
  96. test_prompts: List[Tuple[str, SamplingParams,
  97. Optional[LoRARequest]]]):
  98. """Continuously process a list of prompts and handle the outputs."""
  99. request_id = 0
  100. active_requests = []
  101. for prompt, sampling_params, lora_request in test_prompts:
  102. request_generator = engine.generate(
  103. prompt,
  104. sampling_params,
  105. str(request_id),
  106. lora_request=lora_request
  107. )
  108. active_requests.append(request_generator)
  109. request_id += 1
  110. # Process all requests
  111. for request_generator in active_requests:
  112. # Don't await the generator itself, just iterate over it
  113. async for request_output in request_generator:
  114. if request_output.finished:
  115. print(request_output)
  116. def initialize_engine() -> AsyncAphrodite:
  117. """Initialize the AsyncAphrodite."""
  118. # Function remains unchanged as it's just initialization
  119. engine_args = AsyncEngineArgs(model="NousResearch/Llama-2-7b-hf",
  120. enable_lora=True,
  121. max_loras=1,
  122. max_lora_rank=8,
  123. max_cpu_loras=2,
  124. max_num_seqs=256)
  125. return AsyncAphrodite.from_engine_args(engine_args)
  126. async def main():
  127. """Main function that sets up and runs the prompt processing."""
  128. engine = initialize_engine()
  129. test_prompts = create_test_prompts("alpindale/l2-lora-test")
  130. await process_requests(engine, test_prompts)
# Script entry point: run the async driver to completion.
if __name__ == '__main__':
    asyncio.run(main())