ray_distributed_inference.py

  1. """
  2. This example shows how to use Ray Data for running offline batch inference
  3. distributively on a multi-nodes cluster.
  4. Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
  5. """
  6. from typing import Any, Dict, List
  7. import numpy as np
  8. import ray
  9. from packaging.version import Version
  10. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  11. from aphrodite import LLM, SamplingParams
  12. assert Version(ray.__version__) >= Version(
  13. "2.22.0"), "Ray version must be at least 2.22.0"
  14. # Create a sampling params object.
  15. sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
  16. # Set tensor parallelism per instance.
  17. tensor_parallel_size = 1
  18. # Set number of instances. Each instance will use tensor_parallel_size GPUs.
  19. num_instances = 1
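
# As an illustration only (these values are assumptions, not part of this
# example): on a cluster with two nodes of 4 GPUs each, one model replica
# sharded across the 4 GPUs of each node could be configured as:
#
#   tensor_parallel_size = 4
#   num_instances = 2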


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="NousResearch/Meta-Llama-3.1-8B-Instruct",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt: List[str] = []
        generated_text: List[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (e.g., JSONL, Parquet, CSV, or binary files).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
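
# A minimal sketch of the multi-file case mentioned above; the bucket paths
# are hypothetical placeholders, not real datasets:
#
#   ds = ray.data.read_json("s3://<your-bucket>/prompts/")     # JSONL files
#   ds = ray.data.read_parquet("s3://<your-bucket>/prompts/")  # Parquet files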


# For tensor_parallel_size > 1, we need to create placement groups for
# Aphrodite to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))


resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
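
# Note: with recent Ray versions, map_batches() only builds a lazy execution
# plan; the actual inference runs once the dataset is consumed below (e.g.,
# by take() or write_parquet()).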

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
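#
# For a quick single-node test, the results can also be written to a local
# directory instead; the path below is just an illustrative placeholder:
#
# ds.write_parquet("/tmp/inference_results")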