# gguf_inference.py
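"""Offline inference with a GGUF-quantized TinyLlama chat model via Aphrodite."""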

from huggingface_hub import hf_hub_download

from aphrodite import LLM, SamplingParams


def run_gguf_inference(model_path):
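    """Format chat prompts, load the GGUF model, and print completions."""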
    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
    # Sample prompts.
    prompts = [
        "How many helicopters can a human eat in one sitting?",
        "What's the future of AI?",
    ]
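    # Wrap each raw prompt in the chat template the model expects.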
    prompts = [
        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
        for prompt in prompts
    ]
    # Create a sampling params object.
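    # temperature=0 selects greedy (deterministic) decoding; max_tokens caps
    # each completion at 128 new tokens.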
    sampling_params = SamplingParams(temperature=0, max_tokens=128)

    # Create an LLM.
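    # The tokenizer is taken from the original (unquantized) base model;
    # converting the tokenizer embedded in the GGUF file can be slower and
    # less reliable.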
    llm = LLM(model=model_path,
              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              gpu_memory_utilization=0.95,
              quantization="gguf")

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
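    # Download the GGUF weights from the Hugging Face Hub; hf_hub_download
    # caches the file locally, so subsequent runs reuse it.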
    model = hf_hub_download(repo_id, filename=filename)
    run_gguf_inference(model)