
Merge dev branch into main (#7)

AlpinDale 1 year ago
parent
commit
76b2e4a445

+ 1 - 0
.gitignore

@@ -9,4 +9,5 @@ repos
 build
 *.json
 dist/
+.VSCodeCounter
 

+ 44 - 20
README.md

@@ -17,7 +17,8 @@ Aphrodite builds upon and integrates the exceptional work from various projects,
 - [SkyPilot](https://github.com/skypilot-org/skypilot)
 - [OpenAI Python Library](https://github.com/openai/openai-python)
 
-<h2>Please note that Aphrodite only supports 16-bit HuggingFace models (no GGML or GPTQ). Please refer to the notes below for important information.</h2>
+:warning:
+**Please note that Aphrodite only supports 16-bit HuggingFace models (no GGML or GPTQ). Please refer to the notes below for important information.**
 
 ## Features
 
@@ -36,36 +37,27 @@ Aphrodite builds upon and integrates the exceptional work from various projects,
 
 ## Supported GPUs
 
-Basically, anything with a compute capability of 7.0 or higher. Here's a full list of supported consumer GPUs:
+Basically, anything with a compute capability of 6.0 or higher. Refer to this page for a full list of CUDA GPUs:
 
-| GPU     | CC  | GPU       | CC  | GPU     | CC  |
-| ------- | --- | --------- | --- | ------- | --- |
-| 2060    | 7.5 | 2070      | 7.5 | 2080    | 7.5 |
-| 2080 Ti | 7.5 | Titan RTX | 7.5 | 1650 Ti | 7.5 |
-| 3060    | 8.6 | 3060 Ti   | 8.6 | 3070    | 8.6 |
-| 3070 Ti | 8.6 | 3080      | 8.6 | 3080 Ti | 8.6 |
-| 3090    | 8.6 | 3090 Ti   | 8.6 | 4070 Ti | 8.9 |
-| 4080    | 8.9 | 4090      | 8.9 |         |     |
+[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus).
 
-> \* CC: Compute Capability
 
-Most datacenter/workstation GPUs are supported, so long as they have a compute capability of 7.0 or higher.
-
-If you're unsure, you can find out by opening a Python interpreter and running:
+Or, you can manually find out your GPU's Compute Capability by opening a Python interpreter and running:
 ```py
->>> import torch
+>>> import torch    # if you don't have `torch` installed, run `pip install torch` first
 >>> print(torch.cuda.get_device_capability())
 ```
 This should print something like `(7, 5)`, which indicates a CC of 7.5.
 
-If your GPU is not listed here or you do not meet the minimum CC, you will not be able to run Aphrodite.
+If you do not meet the minimum CC, you will not be able to run Aphrodite.
 
 ## Setting up the environment
+:grey_exclamation:
+**If you run into any problems, please refer to the [Common Issues](#common-issues) section, or open an [Issue](https://github.com/PygmalionAI/aphrodite-engine/issues) if you can't find the answer there.**
 
 Aphrodite will require a slightly specialized environment to run, as the latest CUDA and GCC versions are not supported. You can use Conda to easily configure your environment.
 
 ### Install miniconda3
-If you run into any problems, please refer the common [Common Issues](#common-issues) section, or open an [Issue](https://github.com/PygmalionAI/aphrodite-engine/issues) if you can't find the answer there.
 
 ```sh
 $ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
@@ -78,10 +70,16 @@ You can either source your shell script (`. ~/.bashrc` or `. ~/.zshrc`) or resta
 ### Configuring the env for Aphrodite-engine
 ```sh
 $ conda config --set auto_activate_base false
-$ conda create -n aphrodite python=3.9
+$ conda create -n aphrodite python=3.10
 $ conda activate aphrodite
 $ conda install -c conda-forge cudatoolkit-dev gcc=11.3 gxx=11.3
 ```
+:warning: If you're using an NVIDIA H100 card, please run these install commands instead:
+```sh
+$ conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc=11.8
+$ pip install git+https://github.com/facebookresearch/xformers.git
+```
+
 The last command will take a long time, depending on your internet speed.
 
 Whenever you want to launch Aphrodite later on, make sure you run `conda activate aphrodite` first. The other steps outlined above are one-time only.
@@ -131,6 +129,31 @@ $ curl http://localhost:8000/v1/completions \
         "temperature": 0.8
     }'
 ```
+
+#### Chat API
+```sh
+$ curl -X POST "http://localhost:8000/v1/chat/completions" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      {
+        "role": "system",
+        "content": "Act out the scenario below in a fictional setting."
+      },
+      { "role": "assistant", "content": "[First Message]" },
+      { "role": "user", "content": "---user input---" },
+    ],
+    "model": "EleutherAI/pythia-70m",
+    "temperature": 0.9,
+    "max_tokens": 500,
+    "stream": false,
+    "presence_penalty": 0.7,
+    "frequency_penalty": 0.7,
+    "top_p": 1,
+    "top_k": -1,
+    "logit_bias": {}
+  }'
+```
 For the full list of request parameters, see [OpenAI Completions API reference](https://platform.openai.com/docs/api-reference/completions).
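
The same chat request can also be issued from Python. Below is a minimal client-side sketch using the `requests` package (not part of this repository); the payload mirrors the curl example above and the response is parsed following the OpenAI chat completion schema:

```py
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "model": "EleutherAI/pythia-70m",
        "messages": [
            {"role": "system", "content": "Act out the scenario below in a fictional setting."},
            {"role": "user", "content": "---user input---"},
        ],
        "max_tokens": 500,
        "temperature": 0.9,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```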
 
 ### Common Issues
@@ -144,7 +167,8 @@ $ export CUDA_HOME=/home/anon/miniconda3/envs/aphrodite
 
 Then run the installation command again.
 
-- `Cuda failure 'peer access is not supported between these two devices' [repeated 15x across cluster`
+- `Cuda failure 'peer access is not supported between these two devices' [repeated 15x across cluster]`
+
   
 This would be the last line in a very long error message. This happens if you're using a cluster of NVLinked GPUs and (possibly) using more than 2 of them at once. To fix this, run these two before starting the engine:
 
@@ -166,4 +190,4 @@ You've run out of swap space! Please pass the `--swap-space` followed by the amo
 3. You can view the full list of commands by running `python -m aphrodite.endpoints.openai.api_server --help`.
 
 ## Contributing
-We accept PRs! There will likely be a few typos or other errors we've failed to catch, so please let us know either via an issue or make a Pull Request.
+We accept PRs! There will likely be a few typos or other errors we've failed to catch, so please let us know either via an issue or by making a Pull Request.

+ 23 - 3
aphrodite/common/config.py

@@ -90,11 +90,31 @@ class ModelConfig:
         return self.hf_config.hidden_size
 
     def get_head_size(self) -> int:
+        # TODO: Probably not true for all models, but seems to be true for Llama and NeoX.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        if getattr(self.hf_config, "multi_query", False):
+            return 1
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv // parallel_config.tensor_parallel_size
+        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
+            return self.hf_config.num_key_value_heads // parallel_config.tensor_parallel_size
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
+    
+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len = getattr(self.hf_config, key, None)
+            if max_len is not None:
+                max_model_len = min(max_model_len, max_len)
+        return max_model_len
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
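
As a quick illustration of the new `get_max_model_len` fallback (a standalone sketch with a dummy config object rather than a real HuggingFace config), the method scans the known length attributes and keeps the smallest one that is present, falling back to infinity when none are set:

```py
from types import SimpleNamespace

def get_max_model_len(hf_config) -> float:
    max_model_len = float("inf")
    possible_keys = ["max_sequence_length", "max_seq_length", "seq_len"]
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            max_model_len = min(max_model_len, max_len)
    return max_model_len

print(get_max_model_len(SimpleNamespace(max_sequence_length=2048)))  # 2048
print(get_max_model_len(SimpleNamespace()))                          # inf
```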
@@ -185,17 +205,17 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
-        max_seq_len: Maximum length of a sequence (including prompt and generated text).
+        max_model_len: Maximum length of a sequence (including prompt and generated text).
     """
     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
-        max_seq_len: int,
+        max_model_len: int,
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,

+ 44 - 0
aphrodite/common/logits.py

@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+import torch
+from typing import Dict
+
+
+class LogitsProcessor(ABC):
+
+    @abstractmethod
+    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class BiasLogitsProcessor(LogitsProcessor):
+    """This is to enable logit_bias in the OpenAI server.
+    biases is a dict where each value is -100 to 100
+      according to the OpenAI API docs.
+    Args:
+      biases: Dict of values from -100 to 100 to scale the
+        probability of a token being generated.
+        Each key of the dict corresponds to the token id.
+    """
+
+    def __init__(self, biases: Dict[int, float]):
+        self.biases = biases
+
+        if not biases:
+            return
+
+        self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
+        self.values = torch.tensor(list(self.biases.values()),
+                                   dtype=torch.float)
+
+    def __call__(self, logits):
+        if not self.biases:
+            return logits
+
+        values = self.values.to(logits.device)
+        keys = self.keys.to(logits.device)
+
+        update_factors = torch.where(values >= 0, 1 + (values / 100),
+                                     1 / (1 - (values / 100)))
+        logits[0, keys] *= update_factors
+
+        return logits
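
A rough usage sketch of the new processor (assuming the repository root is on `PYTHONPATH` so that `aphrodite.common.logits` is importable): a positive bias scales the chosen logit up by `1 + bias/100`, a negative bias scales it down by `1 / (1 - bias/100)`:

```py
import torch
from aphrodite.common.logits import BiasLogitsProcessor

# Boost token 3 and suppress token 7 (values follow the OpenAI -100..100 range).
processor = BiasLogitsProcessor({3: 50, 7: -50})

logits = torch.ones(1, 10)          # dummy [1, vocab_size] logits
biased = processor(logits)
print(biased[0, 3].item(), biased[0, 7].item())  # ~1.5 and ~0.667
```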

+ 5 - 0
aphrodite/common/sampling_params.py

@@ -1,5 +1,6 @@
 """Sampling parameters used for text generation."""
 from typing import List, Optional, Union
+from aphrodite.common.logits import LogitsProcessor
 
 _SAMPLING_EPS = 1e-5
 
@@ -39,6 +40,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        logits_processors: List of LogitsProcessors to change the probability
+            of token prediction at runtime.
     """
 
     def __init__(
@@ -55,6 +58,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        logits_processors: List[LogitsProcessor] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -73,6 +77,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.logits_processors = logits_processors
 
         self._verify_args()
         if self.use_beam_search:

+ 39 - 28
aphrodite/endpoints/openai/api_server.py

@@ -11,7 +11,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastchat.conversation import Conversation, SeparatorStyle
-from fastchat.model.model_adapter import get_conversation_template
+#from fastchat.model.model_adapter import get_conversation_template
 
 import uvicorn
 
@@ -29,6 +29,7 @@ from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.utils import random_uuid
+from aphrodite.common.logits import BiasLogitsProcessor
 
 TIMEOUT_KEEP_ALIVE = 5 # seconds
 
@@ -60,7 +61,7 @@ async def check_model(request) -> Optional[JSONResponse]:
 
 
 async def get_gen_prompt(request) -> str:
-    conv = get_conversation_template(request.model)
+    #conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
         system=conv.system,
@@ -95,25 +96,26 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048
+async def check_length(request, prompt):
+    # if hasattr(model_config.hf_config, "max_sequence_length"):
+    #     context_len = model_config.hf_config.max_sequence_length
+    # elif hasattr(model_config.hf_config, "seq_length"):
+    #     context_len = model_config.hf_config.seq_length
+    # elif hasattr(model_config.hf_config, "max_position_embeddings"):
+    #     context_len = model_config.hf_config.max_position_embeddings
+    # elif hasattr(model_config.hf_config, "seq_length"):
+    #     context_len = model_config.hf_config.seq_length
+    # else:
+    #     context_len = 2048
+
 
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
         return create_error_response(
             HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
             f"({token_num} in the messages, "
             f"{request.max_tokens} in the completion). "
@@ -167,7 +169,6 @@ async def create_chat_completion(raw_request: Request):
 
     NOTE: Currently we do not support the following features:
         - function_call (Users should implement this by themselves)
-        - logit_bias (to be supported by Aphrodite engine)
     """
     request = ChatCompletionRequest(**await raw_request.json())
     logger.info(f"Received chat completion request: {request}")
@@ -176,16 +177,19 @@ async def create_chat_completion(raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret
 
-    if request.logit_bias is not None:
-        # TODO: support logit_bias in Aphrodite engine.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "logit_bias is not currently supported")
-
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
+    if not request.logit_bias:
+        logit_processors = []
+    else:
+        biases = dict(
+            map(lambda bias: (int(bias[0]), bias[1]),
+                request.logit_bias.items()))
+        logit_processors = [BiasLogitsProcessor(biases)]
+
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.time())
@@ -202,6 +206,7 @@ async def create_chat_completion(raw_request: Request):
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            logits_processors=logit_processors,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
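
End to end, the chat endpoint now turns an OpenAI-style `logit_bias` field into a `BiasLogitsProcessor` attached to the request's `SamplingParams`. Here is a condensed sketch of that path, with a hypothetical bias payload and assuming the `aphrodite` package is importable:

```py
from aphrodite.common.logits import BiasLogitsProcessor
from aphrodite.common.sampling_params import SamplingParams

# OpenAI clients send token ids as JSON string keys; convert them to ints.
raw_logit_bias = {"2435": -100, "640": 10}   # hypothetical request.logit_bias
biases = {int(token_id): bias for token_id, bias in raw_logit_bias.items()}

sampling_params = SamplingParams(
    temperature=0.9,
    max_tokens=500,
    logits_processors=[BiasLogitsProcessor(biases)],
)
```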
@@ -359,10 +364,13 @@ async def create_completion(raw_request: Request):
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "suffix is not currently supported")
 
-    if request.logit_bias is not None:
-        # TODO: support logit_bias in Aphrodite engine.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "logit_bias is not currently supported")
+    if not request.logit_bias:
+        logit_processors = []
+    else:
+        logit_bias = dict(
+            map(lambda logit: (int(logit[0]), logit[1]),
+                request.logit_bias.items()))
+        logit_processors = [BiasLogitsProcessor(logit_bias)]
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
@@ -392,6 +400,7 @@ async def create_completion(raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            logits_processors=logit_processors,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -579,13 +588,15 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncAphrodite.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
-                              tokenizer_mode=engine_args.tokenizer_mode)
+                              tokenizer_mode=engine_args.tokenizer_mode,
+                              trust_remote_code=engine_args.trust_remote_code)
 
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
                 log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

+ 86 - 37
aphrodite/engine/aphrodite_engine.py

@@ -1,17 +1,24 @@
 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, TYPE_CHECKING
 
 from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.engine.ray_tools import DeviceID, initialize_cluster, ray
+from aphrodite.engine.ray_tools import initialize_cluster, ray, RayWorker
 from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
 from aphrodite.common.utils import Counter
-from aphrodite.task_handler.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 
 logger = init_logger(__name__)
@@ -48,7 +55,7 @@ class AphroditeEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -76,28 +83,65 @@ class AphroditeEngine:
             trust_remote_code=model_config.trust_remote_code)
         self.seq_counter = Counter()
 
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now"
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
-            
-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)
+
         self._init_cache()
 
         self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
+
+    def _init_workers(self, distributed_init_method: str):
+        from aphrodite.task_handler.worker import Worker
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+        
+        self.workers : List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+        from aphrodite.task_handler.worker import Worker
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True),
+            )(RayWorker).remote()
+            self.workers.append(worker)
+        
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                                self.model_config,
+                                self.parallel_config,
+                                self.scheduler_config,
+                                None,
+                                None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
         
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -134,11 +178,11 @@ class AphroditeEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(parallel_config)
         # Create the engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
 
@@ -240,15 +284,18 @@ class AphroditeEngine:
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                new_token, new_output_text = detokenize_incrementally(
-                    self.tokenizer,
-                    seq.output_tokens,
-                    seq.get_last_token_id(),
-                    skip_special_tokens=True,
-                )
-                if new_token is not None:
-                    seq.output_tokens.append(new_token)
-                    seq.output_text = new_output_text
+                last_token_id = seq.get_last_token_id()
+                if last_token_id is not None:
+                    new_token, new_output_text = detokenize_incrementally(
+                        self.tokenizer,
+                        seq.output_tokens,
+                        last_token_id,
+                        skip_special_tokens=True,
+                    )
+                    if new_token is not None:
+                        seq.output_tokens.append(new_token)
+                        seq.output_text = new_output_text
+
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Stop the finished sequences."""
@@ -266,7 +313,7 @@ class AphroditeEngine:
                     continue
                 
                 if (seq.get_len() >=
-                        self.scheduler.scheduler_config.max_seq_len):
+                        self.scheduler_config.max_model_len):
                     self.scheduler.free_seq(
                         seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                     continue
@@ -290,9 +337,11 @@ class AphroditeEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
+            # executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)    # FIXME: will this break us?
+            else:
+                executor = getattr(worker, method)
 
             output = executor(*args, **kwargs)
             all_outputs.append(output)

+ 5 - 7
aphrodite/engine/args_tools.py

@@ -16,13 +16,13 @@ class EngineArgs:
     use_np_weights: bool = False
     use_dummy_weights: bool = False
     dtype: str = "auto"
-    seed: int = 42
+    seed: int = 0
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
-    swap_space: int = 4 # in GiB
-    gpu_memory_utilization: float = 0.88
+    swap_space: int = 5 # in GiB
+    gpu_memory_utilization: float = 0.95
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
@@ -81,11 +81,10 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
+        max_model_len = getattr(model_config.hf_config,
                                 'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs, model_config.get_max_model_len())
         return model_config, cache_config, parallel_config, scheduler_config
 
 
@@ -102,4 +101,3 @@ class AsyncEngineArgs(EngineArgs):
         parser.add_argument('--engine-use-ray', action='store_true', help='use Ray to start the Aphrodite Engine in a separate process as the server process.')
         parser.add_argument('--disable-log-requests', action='store_true', help='disable logging requests')
         return parser
-

+ 2 - 2
aphrodite/engine/async_aphrodite.py

@@ -197,13 +197,13 @@ class AsyncAphrodite:
     def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncAphrodite":
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine

+ 59 - 62
aphrodite/engine/ray_tools.py

@@ -1,22 +1,47 @@
 """Ray for distributed multi-node inference: https://github.com/ray-project/ray"""
-import random
-from typing import List, Optional, Tuple
+import socket
+from typing import List, Optional, Tuple, TYPE_CHECKING
+
+from aphrodite.common.config import ParallelConfig
 
 try:
     import ray
+    from ray.air.util.torch_dist import TorchDistributedWorker
+    """Ray wrapper for aphrodite.task_handler.worker, allowing
+    worker to be lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+    class RayWorker(TorchDistributedWorker):
+        def __init__(self) -> None:
+            self.worker = None
+        
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+        
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+        
 except ImportError:
     ray = None
+    TorchDistributedWorker = None
+    RayWorker = None
 
-from aphrodite.common.config import ParallelConfig
-
-DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
+def get_open_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+    
 
 def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster probably with Ray.
 
     Args:
@@ -35,66 +60,38 @@ def initialize_cluster(
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
             raise ImportError("Ray is not installed. Please install Ray to use distributed inference.")
-        ray.init(address=ray_address)
+        ray.init(address=ray_address, ignore_reinit_error=True)
 
     if not parallel_config.worker_use_ray:
-        port = random.randint(10000, 20000)
+        port = get_open_port()
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
-
-    # NOTE: We assume each node has the same number of GPUs.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node['Resources']['GPU']
-        else:
-            assert num_devices_per_node == node['Resources']['GPU'], (
-            "The number of GPUs per node is not uniform.")
-        for key in node['Resources']:
-            if key.startswith('node:'):
-                valid_node_resources.append(key)
+        return distributed_init_method, None
     
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-            raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        bundles = current_placement_group.bundle_specs
+        gpu_bundles = 0
+        for bundle in bundles:
+            assert bundle.get("GPU", 0) <= 1, (
+                "Placement group bundles cannot have more than 1 GPU.")
+            if bundle.get("GPU", 0):
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
     else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
             raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
-        
-    # Let's assign the GPUs to pipeline stages
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = random.randint(10000, 20000)
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PlacementGroup is ready. This will block until
+        # all requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)
 
-    return distributed_init_method, all_stage_devices
+    return None, current_placement_group
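
The `RayWorker` wrapper defined at the top of this file, together with `partial(worker.execute_method.remote, method)` in `aphrodite_engine.py`, relies on a lazy-init delegation pattern. A minimal Ray-free sketch of that pattern (toy class names, for illustration only):

```py
class LazyWrapper:
    """Holds no worker until init_worker() is called, then forwards
    attribute lookups and method calls to the wrapped object."""

    def __init__(self) -> None:
        self.worker = None

    def init_worker(self, worker_init_fn):
        self.worker = worker_init_fn()

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, i.e. for
        # attributes that live on the wrapped worker.
        return getattr(self.worker, name)

    def execute_method(self, method, *args, **kwargs):
        return getattr(self, method)(*args, **kwargs)


class ToyWorker:
    def init_model(self):
        return "model initialized"


wrapper = LazyWrapper()
wrapper.init_worker(ToyWorker)                # construction is deferred
print(wrapper.execute_method("init_model"))   # -> "model initialized"
```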

+ 16 - 10
aphrodite/modeling/layers/attention.py

@@ -23,6 +23,10 @@ class PagedAttention(nn.Module):
     This class takes flattened 1D query, key, and value tensors as input. The input 1D tensors
     can be split into three parts: the prompt tokens, the generation tokens, and the paddings.
 
+    |<------------------------------------------num_valid_tokens------------------------------------------------>|
+    |<-------------num_prompt_tokens---------------->|<--------------num_generation_tokens (M)------------------>|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+
     The prompts might have different lengths, while the generation tokens always have length 1.
     The paddings are appended to make the input length a multiple of 8, which is desirable for
     Tensor cores.
@@ -93,7 +97,7 @@ class PagedAttention(nn.Module):
                                             self.num_queries_per_kv,
                                             dim=1)
 
-        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        # TODO: The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
@@ -103,7 +107,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
             op=self.attn_op,
         )
-        # TODO(woosuk): Unnecessary copy. Optimize.
+        # TODO: Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
         return output
 
@@ -235,8 +239,9 @@ class PagedAttentionWithRoPE(PagedAttention):
         rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
+        num_kv_heads: Optional[int] = None,
     ) -> None:
-        super().__init__(num_heads, head_size, scale)
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
 
         # Create the cos and sin cache.
         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -269,12 +274,12 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         Args:
             positions: shape = [num_tokens]
-                        query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
 
@@ -358,7 +363,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             value: shape = [num_prompt_tokens, num_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # FIXME: Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
         start = 0
@@ -373,7 +378,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 scale=self.scale,
                 op=self.attn_op,
             )
-            # TODO(woosuk): Unnecessary copy. Optimize.
+            # TODO: Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
             start += prompt_len
         return output
@@ -402,6 +407,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,

+ 15 - 0
aphrodite/modeling/layers/sampler.py

@@ -52,6 +52,7 @@ class Sampler(nn.Module):
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                   frequency_penalties, self.vocab_size)
+        logits = _apply_logits_processors(input_metadata, logits, output_tokens)
 
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
@@ -123,6 +124,20 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
                 output_tokens.append(seq_data.output_token_ids)
     return output_tokens
 
+def _apply_logits_processors(
+    input_metadata: InputMetadata,
+    logits: torch.Tensor,
+    output_tokens: List[List[int]]
+) -> torch.Tensor:
+    for _, seq_group in enumerate(input_metadata.seq_groups):
+        _, sampling_params = seq_group
+        logits_processors = sampling_params.logits_processors
+
+        if logits_processors is not None:
+            for logits_processor in logits_processors:
+                logits = logits_processor(logits)
+
+    return logits
 
 def _apply_penalties(
     logits: torch.Tensor,

+ 0 - 2
aphrodite/modeling/megatron/__init__.py

@@ -1,8 +1,6 @@
 import aphrodite.modeling.megatron.parallel_state
 import aphrodite.modeling.megatron.tensor_parallel
 
-# Alias parallel_state as mpu, its legacy name
-mpu = parallel_state
 
 __all__ = [
     "parallel_state",

+ 40 - 20
aphrodite/modeling/models/llama.py

@@ -58,7 +58,7 @@ class LlamaMLP(nn.Module):
                                             bias=False, input_is_parallel=True,
                                             perform_initialization=False)
         if hidden_act != 'silu':
-            raise ValueError(f'Unsupported activation: {hidden_act}. Only silu is currently supported.')
+            raise ValueError(f'Unsupported activation: {hidden_act}. Only silu is currently supported for LLaMA.')
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
@@ -73,19 +73,26 @@ class LlamaAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        num_kv_heads: int,
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            3 * self.total_num_heads * self.head_dim,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
@@ -97,7 +104,8 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.scaling, rotary_dim=self.head_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.scaling, 
+                                           rotary_dim=self.head_dim, num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -108,7 +116,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
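
With grouped-query attention, the fused `qkv_proj` output is no longer three equal chunks, which is why the `chunk(3)` call is replaced by an explicit `split`. A small shape sketch (assumed example sizes, not taken from a real config):

```py
import torch

hidden_size, num_heads, num_kv_heads = 4096, 32, 8
head_dim = hidden_size // num_heads           # 128
q_size = num_heads * head_dim                 # 4096
kv_size = num_kv_heads * head_dim             # 1024

# qkv_proj produces (num_heads + 2 * num_kv_heads) * head_dim features per token.
qkv = torch.randn(10, (num_heads + 2 * num_kv_heads) * head_dim)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)  # [10, 4096], [10, 1024], [10, 1024]
```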
@@ -124,6 +132,7 @@ class LlamaDecoderLayer(nn.Module):
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -169,8 +178,11 @@ class LlamaModel(nn.Module):
         self.vocab_size = config.vocab_size
 
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size, perform_initialization=False)
-        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size,
+                                                   perform_initialization=False)
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+        ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -235,9 +247,18 @@ class LlamaForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        q_proj_shard_size = (self.config.hidden_size // tp_size)
+        kv_proj_shard_size = (self.config.hidden_size //
+                              self.config.num_attention_heads *
+                              self.config.num_key_value_heads // tp_size)
+        attention_weight_specs = [
+            ("q_proj", q_proj_shard_size, 0),
+            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
+            ("v_proj", kv_proj_shard_size,
+             q_proj_shard_size + kv_proj_shard_size),
+        ]
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
@@ -248,7 +269,7 @@ class LlamaForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
                 padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
+                                     tp_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
@@ -256,18 +277,17 @@ class LlamaForCausalLM(nn.Module):
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
             is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(
-                ["q_proj", "k_proj", "v_proj"]):
-                if att_weight_name not in name:
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
                     continue
-                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[0] // 3
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
+
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
+                                        (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
+
                 param_slice.copy_(loaded_weight)
                 is_attention_weight = True
                 break

+ 5 - 3
aphrodite/processing/scheduler.py

@@ -182,11 +182,13 @@ class Scheduler:
                     break
 
                 num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                if num_prompt_tokens > min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens):
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         " and exceeds limit of "
-                        f"{self.scheduler_config.max_seq_len}")
+                        f"{self.scheduler_config.max_model_len}")
                     for seq in seq_group.get_seqs():
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
@@ -417,7 +419,7 @@ class Scheduler:
             """
             raise RuntimeError(
                 "Aborted due to the lack of CPU swap space. Please increase "
-                "the swap space to avoid this error. https://wiki.archlinux.org/title/Swap")
+                "the swap size with `--swap-space 4`")
         mapping = self.block_manager.swap_out(seq_group)
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):

+ 45 - 19
aphrodite/task_handler/worker.py

@@ -1,5 +1,7 @@
+import os
+from typing import Dict, List, Tuple, Optional
 import torch
-from typing import Dict, List, Tuple, Any
+import torch.distributed
 
 from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
 from aphrodite.modeling import get_model, InputMetadata, set_random_seed
@@ -23,8 +25,8 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: int,
-        distributed_init_method: str,
+        rank: Optional[int] = None,
+        distributed_init_method: Optional[str] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -32,22 +34,32 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
-        _init_distributed_environment(parallel_config, rank, distributed_init_method)
+        self.cache_config = None
+        self.block_size = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    def init_model(self):
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        self.rank = self.rank if self.rank is not None else int(
+            os.getenv("RANK", "-1"))
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.device = torch.device(f"cuda:{local_rank}")
+        if self.rank < 0:
+            raise ValueError("Invalid or unspecified rank.")
+        torch.cuda.set_device(self.device)
+
+        _init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method)
 
         set_random_seed(self.model_config.seed)
-        self.model = get_model(model_config)
+        self.model = get_model(self.model_config)
         initialize_all_reduce_launcher(
             self.scheduler_config.max_num_batched_tokens,
             self.model_config.get_hidden_size(),
             self.model_config.dtype,
         )
 
-        # These will be initialize by self.init_cache_engine()
-        self.cache_config = None
-        self.block_size = None
-        self.cache_engine = None
-        self.cache_events = None
-        self.gpu_cache = None
 
     @torch.inference_mode()
     def profile_num_available_blocks(
@@ -269,16 +281,30 @@ class Worker:
 
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
-    rank: int,
-    distributed_init_method: str,
+    rank: int, # TODO: should this be Optional as well?
+    distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
-    )
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the "
+                "torch world size doesn't match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size})")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized."
+        )
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
     torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
                               parallel_config.pipeline_parallel_size)

+ 53 - 0
chat.sh

@@ -0,0 +1,53 @@
+#!/bin/bash
+
+if [ $# -lt 2 ]; then
+  echo "Usage: bash chat_completion.sh <model_name> <max_tokens> [<temperature>] [<stop_sequence>]"
+  exit 1
+fi
+
+ENDPOINT="http://localhost:8000/v1/chat/completions"
+
+MODEL_NAME="$1"
+MAX_TOKENS="$2"
+TEMPERATURE="${3:-1.0}"  # Default temperature is 1.0 if not provided
+STOP_SEQUENCE="${4:-\\n###}"  # Default stop sequence is '\n###' if not provided
+
+echo "Enter your conversation ('q' to quit):"
+CONVERSATION=""
+
+while true; do
+  read -p "You: " USER_INPUT
+
+  if [ "$USER_INPUT" == "q" ]; then
+    echo "Exiting..."
+    exit 0
+  fi
+
+  # Append user input to the conversation
+  CONVERSATION="$CONVERSATION\n### Human: $USER_INPUT"
+
+  DATA=$(cat <<EOF
+{
+  "model": "$MODEL_NAME",
+  "messages": [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "$CONVERSATION"}
+  ],
+  "max_tokens": $MAX_TOKENS,
+  "temperature": $TEMPERATURE,
+  "stop": ["$STOP_SEQUENCE"]
+}
+EOF
+)
+
+  RESPONSE=$(curl -s -X POST -H "Content-Type: application/json" \
+                            -d "$DATA" \
+                            $ENDPOINT)
+
+  AI_REPLY=$(echo "$RESPONSE" | jq -r '.choices[0].message.content')
+
+  # Remove any generated text after the stop sequence
+  AI_REPLY=$(echo "$AI_REPLY" | sed -n "/$STOP_SEQUENCE/q;p")
+
+  echo -e "\033[1;35mBot:\033[0m \033[1;32m$AI_REPLY\033[0m"
+done

+ 2 - 2
examples/gradio_server.py

@@ -10,7 +10,7 @@ def http_bot(prompt):
     pload = {
         "prompt": prompt,
         "stream": True,
-        "max_tokens": 128,
+        "max_tokens": 512,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
 
@@ -41,4 +41,4 @@ if __name__ == "__main__":
     demo = build_demo()
     demo.queue(concurrency_count=100).launch(server_name=args.host,
                                              server_port=args.port,
-                                             share=True)    
+                                             share=False)    

+ 0 - 42
examples/openai_client.py

@@ -1,42 +0,0 @@
-import argparse
-import openai
-
-openai.api_key = "sasuga"
-openai.api_base = "http://localhost:8000/v1"
-model = "PygmalionAI/pygmalion-350m"
-
-models = openai.Model.list()
-print("Models:", models)
-
-def get_completions(prompt, use_chat_completions):
-    if use_chat_completions:
-        completions = openai.Completion.create(
-            model=model,
-            message=[
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": prompt}
-            ],
-            max_tokens=50,
-        )
-        return completions.choices[-1].message["content"]
-    else:
-        completion = openai.Completion.create(
-            model=model,
-            prompt=prompt,
-            echo=False,
-            n=2,
-            best_of=3,
-            logprobs=3,
-        )
-        return completion.choices[0].text.strip()
-    
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--use-chat-completions", action="store_true")
-    parser.add_argument("--prompt", type=str, default="A robot may injure a human being")
-    args = parser.parse_args()
-
-    completions = get_completions(args.prompt, args.use_chat_completions)
-
-    print("Completion result:")
-    print(completions)

+ 1 - 1
kernels/cache_kernels.cu

@@ -31,7 +31,7 @@ void swap_blocks(
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  // NOTE(woosuk): This can be slow if the number of blocks is large.
+  // NOTE: This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
     int64_t src_block_number = pair.first;
     int64_t dst_block_number = pair.second;

+ 13 - 8
kernels/pos_encoding_kernels.cu

@@ -7,11 +7,12 @@ template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -19,8 +20,8 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * stride + head_idx * head_size;
 
@@ -39,10 +40,12 @@ __global__ void rotary_embedding_neox_kernel(
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
 
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    if (head_idx < num_kv_heads) {
+      const scalar_t k_x = key[token_head + x_index];
+      const scalar_t k_y = key[token_head + y_index];
+      key[out_x] = k_x * cos - k_y * sin;
+      key[out_y] = k_y * cos + k_x * sin;
+    }
   }
 }
 
@@ -51,13 +54,14 @@ __global__ void rotary_embedding_neox_kernel(
 void rotary_embedding_neox(
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
 {
   int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
   int stride = query.stride(0);
   TORCH_CHECK(stride == key.stride(0));
 
@@ -78,6 +82,7 @@ void rotary_embedding_neox(
         rot_dim,
         stride,
         num_heads,
+        num_kv_heads,
         head_size);
     });
 }
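The hunks above let the key tensor carry fewer heads than the query tensor (as in multi-query / grouped-query attention): the new `head_idx < num_kv_heads` guard skips key rotation for head indices that only exist on the query side. For intuition only, here is a NumPy sketch of NeoX-style rotary embedding applied to a query with 8 heads and a key with 2 heads; the function and constants are illustrative, not part of the codebase or the kernel.

```py
# Illustrative NumPy sketch of NeoX-style rotary embedding with fewer
# key/value heads than query heads; not the CUDA kernel itself.
import numpy as np

def rotary_neox(x: np.ndarray, pos: int, rot_dim: int, base: float = 10000.0) -> np.ndarray:
    # x: [num_heads, head_size]; only the first rot_dim dims of each head are rotated.
    half = rot_dim // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
    x1, x2 = x[:, :half], x[:, half:rot_dim]             # NeoX style: first half / second half
    out = x.copy()
    out[:, :half] = x1 * cos - x2 * sin
    out[:, half:rot_dim] = x2 * cos + x1 * sin
    return out

num_heads, num_kv_heads, head_size, pos = 8, 2, 64, 5
query = np.random.randn(num_heads, head_size)
key = np.random.randn(num_kv_heads, head_size)            # fewer heads, as in the kernel
query_rot = rotary_neox(query, pos, rot_dim=head_size)
key_rot = rotary_neox(key, pos, rot_dim=head_size)        # mirrors the head_idx < num_kv_heads guard
```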

+ 3 - 3
requirements.txt

@@ -1,14 +1,14 @@
 ninja
 psutil
-ray
+ray >= 2.5.1
 sentencepiece
 numpy
 torch >= 2.0.0
-transformers >= 4.28.0
+transformers >= 4.31.0
 uvicorn
 openai # for fastapi's openai proxy emulation
 xformers >= 0.0.19
 mypy
 pytest
-fschat
+fschat >= 0.2.18
 pydantic < 2

+ 5 - 5
setup.py

@@ -13,7 +13,7 @@ ROOT_DIR = os.path.dirname(__file__)
 
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
-# TODO(woosuk): Should we use -O3?
+# TODO: Should we use -O3?
 NVCC_FLAGS = ["-O2", "-std=c++17"]
 
 ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
@@ -43,13 +43,13 @@ device_count = torch.cuda.device_count()
 compute_capabilities: Set[int] = set()
 for i in range(device_count):
     major, minor = torch.cuda.get_device_capability(i)
-    if major < 7:
+    if major < 6:
         raise RuntimeError(
-            "GPUs with compute capability less than 7.0 are not supported.")
+            "GPUs with compute capability less than 6.0 are not supported.")
     compute_capabilities.add(major * 10 + minor)
 # If no GPU is available, add all supported compute capabilities.
 if not compute_capabilities:
-    compute_capabilities = {70, 75, 80, 86, 90}
+    compute_capabilities = {60, 61, 65, 70, 75, 80, 86, 89, 90}
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
     NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
@@ -172,4 +172,4 @@ setuptools.setup(
     cmdclass={"build_ext": BuildExtension},
 )
 
-    
+    
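As a quick illustration of the `-gencode` flag construction earlier in this file's diff, here is what the loop produces for a hypothetical build machine whose detected capabilities are 8.6 and 8.9 (e.g. an RTX 3090 and an RTX 4090); the values are examples only.

```py
# Example of the NVCC flag construction in setup.py for a hypothetical
# pair of detected compute capabilities (8.6 and 8.9).
compute_capabilities = {86, 89}
NVCC_FLAGS = []
for capability in sorted(compute_capabilities):
    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
print(NVCC_FLAGS)
# ['-gencode', 'arch=compute_86,code=sm_86', '-gencode', 'arch=compute_89,code=sm_89']
```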

+ 82 - 0
tests/latency.py

@@ -0,0 +1,82 @@
+import argparse
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from aphrodite import LLM, SamplingParams
+
+
+def main(args: argparse.Namespace):
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE: If the request cannot be processed in a single batch,
+    # the engine will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tokenizer=args.tokenizer,
+        tensor_parallel_size=args.tensor_parallel_size,
+        swap_space=args.swap_space,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+        trust_remote_code=args.trust_remote_code,
+    )
+
+    sampling_params = SamplingParams(
+        n=args.n,
+        temperature=0.0 if args.use_beam_search else 1.0,
+        top_p=1.0,
+        use_beam_search=args.use_beam_search,
+        ignore_eos=True,
+        max_tokens=args.output_len,
+    )
+    print(sampling_params)
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
+
+    def run_to_completion(profile: bool = False):
+        if profile:
+            torch.cuda.cudart().cudaProfilerStart()
+        start_time = time.time()
+
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
+                     use_tqdm=False)
+
+        end_time = time.time()
+        latency = end_time - start_time
+        if profile:
+            torch.cuda.cudart().cudaProfilerStop()
+        return latency
+
+    print("Warming up...")
+    run_to_completion(profile=False)
+
+    # Benchmark.
+    latencies = []
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
+    print(f'Avg latency: {np.mean(latencies)} seconds')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+    parser.add_argument('--input-len', type=int, default=32)
+    parser.add_argument('--output-len', type=int, default=128)
+    parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', '-n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
+    parser.add_argument('--swap-space', type=int, default=4)
+    parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
+    args = parser.parse_args()
+    main(args)
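To put the printed average latency in perspective: it covers a whole batch generating `output_len` tokens per sequence, so a rough throughput figure can be derived from it. The numbers below are made up for illustration, not measured results.

```py
# Rough conversion from the script's average latency to output throughput;
# the latency value here is a made-up example, not a measured result.
batch_size, output_len = 8, 128      # defaults from the argument parser above
avg_latency_s = 6.4                  # hypothetical "Avg latency" printed by the script
tokens_per_second = batch_size * output_len / avg_latency_s
print(f"{tokens_per_second:.1f} output tokens/s")  # -> 160.0
```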

+ 214 - 0
tests/serving.py

@@ -0,0 +1,214 @@
+import argparse
+import asyncio
+import json
+import random
+import time
+from typing import AsyncGenerator, List, Tuple
+
+import aiohttp
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+# (prompt len, output len, latency)
+REQUEST_LATENCY: List[Tuple[int, int, float]] = []
+
+
+def sample_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+    # Filter out the conversations with less than 2 turns.
+    dataset = [
+        data for data in dataset
+        if len(data["conversations"]) >= 2
+    ]
+    # Only keep the first two turns of each conversation.
+    dataset = [
+        (data["conversations"][0]["value"], data["conversations"][1]["value"])
+        for data in dataset
+    ]
+
+    # Tokenize the prompts and completions.
+    prompts = [prompt for prompt, _ in dataset]
+    prompt_token_ids = tokenizer(prompts).input_ids
+    completions = [completion for _, completion in dataset]
+    completion_token_ids = tokenizer(completions).input_ids
+    tokenized_dataset = []
+    for i in range(len(dataset)):
+        output_len = len(completion_token_ids[i])
+        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+    # Filter out too long sequences.
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+        prompt_len = len(prompt_token_ids)
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+        filtered_dataset.append((prompt, prompt_len, output_len))
+
+    # Sample the requests.
+    sampled_requests = random.sample(filtered_dataset, num_requests)
+    return sampled_requests
+
+
+async def get_request(
+    input_requests: List[Tuple[str, int, int]],
+    request_rate: float,
+) -> AsyncGenerator[Tuple[str, int, int], None]:
+    input_requests = iter(input_requests)
+    for request in input_requests:
+        yield request
+
+        if request_rate == float("inf"):
+            # If the request rate is infinity, then we don't need to wait.
+            continue
+        # Sample the request interval from the exponential distribution.
+        interval = np.random.exponential(1.0 / request_rate)
+        # The next request will be sent after the interval.
+        await asyncio.sleep(interval)
+
+
+async def send_request(
+    backend: str,
+    api_url: str,
+    prompt: str,
+    prompt_len: int,
+    output_len: int,
+    best_of: int,
+    use_beam_search: bool,
+) -> None:
+    request_start_time = time.time()
+
+    headers = {"User-Agent": "Benchmark Client"}
+    if backend == "aphrodite":
+        pload = {
+            "prompt": prompt,
+            "n": 1,
+            "best_of": best_of,
+            "use_beam_search": use_beam_search,
+            "temperature": 0.0 if use_beam_search else 1.0,
+            "top_p": 1.0,
+            "max_tokens": output_len,
+            "ignore_eos": True,
+            "stream": False,
+        }
+    elif backend == "tgi":
+        assert not use_beam_search
+        params = {
+            "best_of": best_of,
+            "max_new_tokens": output_len,
+            "do_sample": True,
+        }
+        pload = {
+            "inputs": prompt,
+            "parameters": params,
+        }
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+    timeout = aiohttp.ClientTimeout(total=3 * 3600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        while True:
+            async with session.post(api_url, headers=headers, json=pload) as response:
+                chunks = []
+                async for chunk, _ in response.content.iter_chunks():
+                    chunks.append(chunk)
+            output = b"".join(chunks).decode("utf-8")
+            output = json.loads(output)
+
+            # Re-send the request if it failed.
+            if "error" not in output:
+                break
+
+    request_end_time = time.time()
+    request_latency = request_end_time - request_start_time
+    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+
+
+async def benchmark(
+    backend: str,
+    api_url: str,
+    input_requests: List[Tuple[str, int, int]],
+    best_of: int,
+    use_beam_search: bool,
+    request_rate: float,
+) -> None:
+    tasks: List[asyncio.Task] = []
+    async for request in get_request(input_requests, request_rate):
+        prompt, prompt_len, output_len = request
+        task = asyncio.create_task(send_request(backend, api_url, prompt,
+                                                prompt_len, output_len,
+                                                best_of, use_beam_search))
+        tasks.append(task)
+    await asyncio.gather(*tasks)
+
+
+def main(args: argparse.Namespace):
+    print(args)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    api_url = f"http://{args.host}:{args.port}/generate"
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+
+    benchmark_start_time = time.time()
+    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
+                          args.use_beam_search, args.request_rate))
+    benchmark_end_time = time.time()
+    benchmark_time = benchmark_end_time - benchmark_start_time
+    print(f"Total time: {benchmark_time:.2f} s")
+    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
+
+    # Compute the latency statistics.
+    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
+    print(f"Average latency: {avg_latency:.2f} s")
+    avg_per_token_latency = np.mean([
+        latency / (prompt_len + output_len)
+        for prompt_len, output_len, latency in REQUEST_LATENCY
+    ])
+    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
+    avg_per_output_token_latency = np.mean([
+        latency / output_len
+        for _, output_len, latency in REQUEST_LATENCY
+    ])
+    print("Average latency per output token: "
+          f"{avg_per_output_token_latency:.2f} s")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Benchmark the online serving throughput.")
+    parser.add_argument("--backend", type=str, default="aphrodite",
+                        choices=["aphrodite", "tgi"])
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--dataset", type=str, required=True,
+                        help="Path to the dataset.")
+    parser.add_argument("--tokenizer", type=str, required=True,
+                        help="Name or path of the tokenizer.")
+    parser.add_argument("--best-of", type=int, default=1,
+                        help="Generates `best_of` sequences per prompt and "
+                             "returns the best one.")
+    parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--num-prompts", type=int, default=1000,
+                        help="Number of prompts to process.")
+    parser.add_argument("--request-rate", type=float, default=float("inf"),
+                        help="Number of requests per second. If this is inf, "
+                             "then all the requests are sent at time 0. "
+                             "Otherwise, we use a Poisson process to synthesize "
+                             "the request arrival times.")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
+    args = parser.parse_args()
+    main(args)
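The `--request-rate` flag feeds `get_request` above, which draws inter-arrival times from an exponential distribution with mean `1 / request_rate`, i.e. Poisson arrivals at that rate. A small standalone sanity check of that relationship follows; the rate is an example value, not a benchmark setting.

```py
# Standalone check that exponential inter-arrival times with mean
# 1/request_rate yield roughly request_rate arrivals per second.
import numpy as np

request_rate = 4.0                                        # example rate (requests/s)
intervals = np.random.exponential(1.0 / request_rate, size=100_000)
observed_rate = len(intervals) / intervals.sum()
print(f"target: {request_rate:.2f}/s  observed: {observed_rate:.2f}/s")
```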

+ 2 - 2
tests/throughput.py

@@ -84,7 +84,7 @@ def run_aphrodite(
             ignore_eos=True,
             max_tokens=output_len,
         )
-        # FIXME(woosuk): Do not use internal method.
+        # FIXME: Do not use internal method.
         llm._add_request(
             prompt=prompt,
             prompt_token_ids=None,
@@ -92,7 +92,7 @@ def run_aphrodite(
         )
 
     start = time.time()
-    # FIXME(woosuk): Do use internal method.
+    # FIXME: Do not use internal method.
     llm._run_engine(use_tqdm=True)
     end = time.time()
     return end - start