# neuron_inference.py
  1. import os
  2. from aphrodite import LLM, SamplingParams
  3. # creates XLA hlo graphs for all the context length buckets.
  4. os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
  5. # creates XLA hlo graphs for all the token gen buckets.
  6. os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
  7. # Sample prompts.
  8. prompts = [
  9. "Once upon a time,",
  10. "In a galaxy far, far away,",
  11. "The quick brown fox jumps over the lazy dog.",
  12. "The meaning of life is",
  13. ]
  14. # Create a sampling params object.
  15. sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
  16. # Create an LLM.
  17. llm = LLM(
  18. model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  19. max_num_seqs=8,
  20. # The max_model_len and block_size arguments are required to be same as
  21. # max sequence length when targeting neuron device.
  22. # Currently, this is a known limitation in continuous batching support
  23. # in transformers-neuronx.
  24. # TODO: Support paged-attention in transformers-neuronx.
  25. max_model_len=2048,
  26. block_size=2048,
  27. # The device can be automatically detected when AWS Neuron SDK is installed.
  28. # The device argument can be either unspecified for automated detection,
  29. # or explicitly assigned.
  30. device="neuron",
  31. tensor_parallel_size=2)
  32. # Generate texts from the prompts. The output is a list of RequestOutput objects
  33. # that contain the prompt, generated text, and other information.
  34. outputs = llm.generate(prompts, sampling_params)
  35. # Print the outputs.
  36. for output in outputs:
  37. prompt = output.prompt
  38. generated_text = output.outputs[0].text
  39. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")