# mlpspeculator_inference.py
import gc
import time
from typing import List

from aphrodite import LLM, SamplingParams
  5. def time_generation(llm: LLM, prompts: List[str],
  6. sampling_params: SamplingParams):
  7. # Generate texts from the prompts. The output is a list of RequestOutput
  8. # objects that contain the prompt, generated text, and other information.
  9. # Warmup first
  10. llm.generate(prompts, sampling_params)
  11. llm.generate(prompts, sampling_params)
  12. start = time.time()
  13. outputs = llm.generate(prompts, sampling_params)
  14. end = time.time()
  15. print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
  16. # Print the outputs.
  17. for output in outputs:
  18. generated_text = output.outputs[0].text
  19. print(f"text: {generated_text!r}")
  20. if __name__ == "__main__":
  21. template = (
  22. "Below is an instruction that describes a task. Write a response "
  23. "that appropriately completes the request.\n\n### Instruction:\n{}"
  24. "\n\n### Response:\n")
  25. # Sample prompts.
  26. prompts = [
  27. "Write about the president of the United States.",
  28. ]
  29. prompts = [template.format(prompt) for prompt in prompts]
  30. # Create a sampling params object.
  31. sampling_params = SamplingParams(temperature=0.0, max_tokens=200)
  32. # Create an LLM without spec decoding
  33. llm = LLM(model="NousResearch/Meta-Llama-3.1-8B-Instruct",
  34. max_model_len=8192)
  35. print("Without speculation")
  36. time_generation(llm, prompts, sampling_params)
  37. del llm
  38. gc.collect()
  39. # Create an LLM with spec decoding
  40. llm = LLM(
  41. model="NousResearch/Meta-Llama-3.1-8B-Instruct",
  42. speculative_model="ibm-fms/llama3-8b-accelerator",
  43. # These are currently required for MLPSpeculator decoding
  44. use_v2_block_manager=True,
  45. enforce_eager=True,
  46. max_model_len=8192,
  47. )
  48. print("With speculation")
  49. time_generation(llm, prompts, sampling_params)