# neuron_int8_quantization.py
  1. import os
  2. from aphrodite import LLM, SamplingParams
  3. # creates XLA hlo graphs for all the context length buckets.
  4. os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
  5. # creates XLA hlo graphs for all the token gen buckets.
  6. os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
  7. # Quantizes neuron model weight to int8 ,
  8. # The default config for quantization is int8 dtype.
  9. os.environ["NEURON_QUANT_DTYPE"] = "s8"
  10. # Sample prompts.
  11. prompts = [
  12. "Hello, my name is",
  13. "The president of the United States is",
  14. "The capital of France is",
  15. "The future of AI is",
  16. ]
  17. # Create a sampling params object.
  18. sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
  19. # Create an LLM.
  20. llm = LLM(
  21. model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  22. max_num_seqs=8,
  23. # The max_model_len and block_size arguments are required to be same as
  24. # max sequence length when targeting neuron device.
  25. # Currently, this is a known limitation in continuous batching support
  26. # in transformers-neuronx.
  27. # TODO(liangfu): Support paged-attention in transformers-neuronx.
  28. max_model_len=2048,
  29. block_size=2048,
  30. # The device can be automatically detected when AWS Neuron SDK is installed.
  31. # The device argument can be either unspecified for automated detection,
  32. # or explicitly assigned.
  33. device="neuron",
  34. quantization="neuron_quant",
  35. override_neuron_config={
  36. "cast_logits_dtype": "bfloat16",
  37. },
  38. tensor_parallel_size=2,
  39. )
  40. # Generate texts from the prompts. The output is a list of RequestOutput objects
  41. # that contain the prompt, generated text, and other information.
  42. outputs = llm.generate(prompts, sampling_params)
  43. # Print the outputs.
  44. for output in outputs:
  45. prompt = output.prompt
  46. generated_text = output.outputs[0].text
  47. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")