
chore: backlogs 1 (#191)

AlpinDale · 1 year ago
commit b9b295d74e

+ 10 - 18
aphrodite/common/config.py

@@ -95,23 +95,20 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = ["safetensors"]
+        rocm_not_supported_load_format = []
         if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip():
-            if load_format in ["safetensors"]:
-                rocm_supported_load_format = [
-                    f for f in supported_load_format
-                    if (f not in rocm_not_supported_load_format)
-                ]
-                raise ValueError(
-                    f"load format {load_format} is not supported on ROCm. "
-                    f"Must be one of {rocm_supported_load_format}.")
-            # force ROCm to load from pt weights if nothing is set
-            if load_format == "auto":
-                load_format = "pt"
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in supported_load_format
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format \'{load_format}\' is not supported in ROCm. "
+                f"Supported load format are "
+                f"{rocm_supported_load_format}")

         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
@@ -166,11 +163,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True

     def verify_with_parallel_config(
         self,

+ 11 - 1
aphrodite/endpoints/openai/api_server.py

@@ -87,6 +87,14 @@ def parse_args():
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=True.")
+    parser.add_argument("--ssl-keyfile",
+                        type=str,
+                        default=None,
+                        help="SSL key file path.")
+    parser.add_argument("--ssl-certfile",
+                        type=str,
+                        default=None,
+                        help="SSL cert file path.")

     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser.parse_args()
@@ -819,4 +827,6 @@ if __name__ == "__main__":
                 host=args.host,
                 port=args.port,
                 log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                ssl_keyfile=args.ssl_keyfile,
+                ssl_certfile=args.ssl_certfile)
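The two new flags are handed straight to uvicorn, which performs TLS termination itself. A minimal, self-contained sketch of the same mechanism, with a placeholder app, port, and certificate paths (none of these are part of the commit):

    # Sketch only: the app, port, and paths below are placeholders.
    from fastapi import FastAPI
    import uvicorn

    app = FastAPI()

    @app.get("/health")
    def health():
        return {"status": "ok"}

    if __name__ == "__main__":
        uvicorn.run(app,
                    host="0.0.0.0",
                    port=8000,
                    ssl_keyfile="/path/to/key.pem",     # --ssl-keyfile
                    ssl_certfile="/path/to/cert.pem")   # --ssl-certfile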

+ 16 - 15
aphrodite/engine/aphrodite_engine.py

@@ -1,4 +1,5 @@
 import copy
+import os
 import time
 from functools import partial
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
@@ -15,7 +16,6 @@ from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                                       SequenceGroupMetadata,
                                        SequenceGroupOutput, SequenceOutput,
                                        SequenceStatus)
 from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
@@ -106,6 +106,10 @@ class AphroditeEngine:

         # Create the parallel GPU workers.
         if self.parallel_config.worker_use_ray:
+            # Disable Ray usage stats collection.
+            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+            if ray_usage != "1":
+                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
             self._init_workers_ray(placement_group)
         else:
             self._init_workers(distributed_init_method)
@@ -250,6 +254,15 @@ class AphroditeEngine:
                              "Try increasing `gpu_memory_utilization` when "
                              "initializing the engine.")

+        max_seq_len = self.cache_config.block_size * num_gpu_blocks
+        if self.model_config.max_model_len > max_seq_len:
+            raise ValueError(
+                f"The model's max seq len ({self.model_config.max_model_len}) "
+                "is larger than the maximum number of tokens that can be "
+                f"stored in KV cache ({max_seq_len}). Try increasing "
+                "`gpu_memory_utilization` or decreasing `max_model_len` when "
+                "initializing the engine.")
+
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

@@ -337,16 +350,6 @@ class AphroditeEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],
@@ -597,9 +600,7 @@ class AphroditeEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

         # Execute the model.
         output = self._run_workers(
@@ -608,7 +609,7 @@ class AphroditeEngine:
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        ) if not scheduler_outputs.is_empty() else []

         return self._process_model_outputs(output, scheduler_outputs)
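The new guard catches configurations where the profiled KV cache cannot hold even one sequence of the requested context length. A standalone sketch of the arithmetic, with made-up numbers in place of the profiled values:

    # Illustrative numbers only; in the engine these come from the cache config
    # and the GPU memory profiling step.
    block_size = 16          # tokens per KV-cache block
    num_gpu_blocks = 512     # blocks that fit in GPU memory after profiling
    max_model_len = 16384    # requested context length

    max_seq_len = block_size * num_gpu_blocks  # 8192 tokens of KV-cache capacity
    if max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) is larger than the "
            f"maximum number of tokens that can be stored in KV cache "
            f"({max_seq_len}). Try increasing `gpu_memory_utilization` or "
            "decreasing `max_model_len` when initializing the engine.")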
 
 

+ 4 - 6
aphrodite/engine/async_aphrodite.py

@@ -182,20 +182,18 @@ class _AsyncAphrodite(AphroditeEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

         # Execute the model.
-        output = await self._run_workers_async(
+        output = (await self._run_workers_async(
             "execute_model",
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        )) if not scheduler_outputs.is_empty() else []

-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)

     async def _run_workers_async(
         self,

+ 3 - 0
aphrodite/modeling/models/gpt_neox.py

@@ -59,6 +59,7 @@ class GPTNeoXAttention(nn.Module):
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
+        self.bias = getattr(config, "attention_bias", True)

         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
@@ -70,11 +71,13 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             self.head_size,
             self.total_num_heads,
+            bias=self.bias,
             linear_method=linear_method,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
+            bias=self.bias,
             linear_method=linear_method,
         )
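`attention_bias` is read with a default of `True` because older GPT-NeoX configs do not define the field. A quick sketch of the fallback; the config objects below are stand-ins, not real `GPTNeoXConfig` instances:

    # Stand-in config objects, only to show the getattr default.
    from types import SimpleNamespace

    legacy_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
    new_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32,
                                 attention_bias=False)

    print(getattr(legacy_config, "attention_bias", True))  # True  (old behaviour)
    print(getattr(new_config, "attention_bias", True))     # False (explicit)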
 
 

+ 5 - 26
aphrodite/modeling/models/mixtral.py

@@ -50,7 +50,6 @@ from aphrodite.modeling.megatron.parallel_state import (
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.common.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -95,30 +94,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states


-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):

     def __init__(
@@ -148,7 +123,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -433,6 +408,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
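With `DummyModule` replaced by `None`, experts owned by other workers no longer have entries in `params_dict`, so the loader has to skip their checkpoint tensors by name. A rough sketch of that filter with illustrative parameter names (not taken from a real checkpoint):

    # Illustrative names; only expert 0 is assigned to this worker.
    params_dict = {
        "model.layers.0.block_sparse_moe.experts.0.w1.weight": "param",
    }
    checkpoint_names = [
        "model.layers.0.block_sparse_moe.experts.0.w1.weight",
        "model.layers.0.block_sparse_moe.experts.3.w1.weight",
    ]

    for name in checkpoint_names:
        # Skip experts that are not assigned to this worker.
        if "block_sparse_moe.experts." in name and name not in params_dict:
            continue
        print("loading", name)  # only the expert-0 tensor is loaded here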

+ 3 - 3
aphrodite/processing/block_manager.py

@@ -103,7 +103,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME: Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.get_seqs()[0]
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
         if self.block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
@@ -121,7 +121,7 @@ class BlockSpaceManager:
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
         # prompt.
-        seq = seq_group.get_seqs()[0]
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
@@ -136,7 +136,7 @@ class BlockSpaceManager:
             block_table.append(block)

         # Assign the block table for each sequence.
-        for seq in seq_group.get_seqs():
+        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()

     def can_append_slot(self, seq_group: SequenceGroup) -> bool:
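Filtering on `SequenceStatus.WAITING` matters once a group can contain sequences in more than one state, so allocation only considers the sequences still waiting for their prompt. A toy sketch of the filtering semantics; the classes below are stand-ins, not the real `Sequence`/`SequenceGroup`:

    # Toy stand-ins, only to show what get_seqs(status=...) filters.
    from enum import Enum, auto

    class SequenceStatus(Enum):
        WAITING = auto()
        RUNNING = auto()

    class Seq:
        def __init__(self, seq_id, status):
            self.seq_id, self.status = seq_id, status

    class SeqGroup:
        def __init__(self, seqs):
            self.seqs = seqs
        def get_seqs(self, status=None):
            return [s for s in self.seqs
                    if status is None or s.status == status]

    group = SeqGroup([Seq(0, SequenceStatus.RUNNING),
                      Seq(1, SequenceStatus.WAITING)])
    print([s.seq_id for s in group.get_seqs()])                               # [0, 1]
    print([s.seq_id for s in group.get_seqs(status=SequenceStatus.WAITING)])  # [1]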

+ 7 - 5
aphrodite/processing/scheduler.py

@@ -139,15 +139,17 @@ class Scheduler:
             while self.waiting:
                 seq_group = self.waiting[0]

-                assert seq_group.num_seqs() == 1, (
+                waiting_seqs = seq_group.get_seqs(
+                    status=SequenceStatus.WAITING)
+                assert len(waiting_seqs) == 1, (
                     "Waiting sequence group should have only one prompt "
                     "sequence.")
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                num_prompt_tokens = waiting_seqs[0].get_len()
                 if num_prompt_tokens > self.prompt_limit:
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         f" and exceeds limit of {self.prompt_limit}")
-                    for seq in seq_group.get_seqs():
+                    for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
                     self.waiting.pop(0)
@@ -161,7 +163,7 @@ class Scheduler:
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         f" and exceeds the capacity of the block manager.")
-                    for seq in seq_group.get_seqs():
+                    for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
                     self.waiting.pop(0)
@@ -320,7 +322,7 @@ class Scheduler:

     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.get_seqs():
+        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             seq.status = SequenceStatus.RUNNING

     def _append_slot(

+ 3 - 0
kernels/activation_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>

 #include "cuda_compat.h"
@@ -36,6 +37,7 @@ void silu_and_mul(

   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d;                                                 \
   dim3 grid(num_tokens);                                                                  \
   dim3 block(std::min(d, 1024));                                                          \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
   APHRODITE_DISPATCH_FLOATING_TYPES(                                                           \
     input.scalar_type(),                                                                  \
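The `OptionalCUDAGuard` lines here (and in the other kernel files below) make each launcher switch to the device of its input tensor before querying the stream and launching, so the kernels behave when the tensors live on a GPU other than the currently selected one. A rough Python-level analogue of the same idea, not the extension code itself:

    # Python-level analogue of the device-guard pattern; needs two GPUs to show.
    import torch

    def launch_on_input_device(x: torch.Tensor) -> torch.Tensor:
        # Roughly what `OptionalCUDAGuard device_guard(device_of(x))` does in C++:
        # temporarily select x's device so the stream and launch match it.
        with torch.cuda.device(x.device):
            return x * 2  # stand-in for the real kernel launch

    if torch.cuda.device_count() >= 2:
        x = torch.ones(4, device="cuda:1")
        print(launch_on_input_device(x).device)  # cuda:1, even if cuda:0 is current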

+ 3 - 0
kernels/attention/attention_kernels.cu

@@ -21,6 +21,7 @@
 #endif

 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>

 #include "attention_dtypes.h"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(

   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE: To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE: To reduce the compilation time, we only compile for the

+ 5 - 0
kernels/cache_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>

 #include "cuda_compat.h"
@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());

   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE: This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),

+ 3 - 0
kernels/layernorm_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>

 #include "dispatch_utils.h"
@@ -75,6 +76,7 @@ void rms_norm(

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -100,6 +102,7 @@ void fused_add_rms_norm(

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

+ 2 - 0
kernels/pos_encoding_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>

 #include "cuda_compat.h"
@@ -94,6 +95,7 @@ void rotary_embedding(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),

+ 1 - 1
kernels/quantization/gptq/matrix_view.cuh

@@ -147,5 +147,5 @@ public:
 };

 }  // namespace gptq
-}  // namespace vllm
+}  // namespace aphrodite
 #endif

+ 23 - 7
kernels/quantization/gptq/q_gemm.cu

@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))

 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -286,7 +287,8 @@ void gemm_half_q_half_cuda_part

     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);

-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -433,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);

-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -520,16 +523,25 @@ __global__ void gemm_half_q_half_alt_kernel(
             zeros_tmp[tmp_k] = zero;
         }
         for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
            res2 = {};
+#else
+            res2.x = __half_as_ushort(__float2half(0));
+            res2.y = __half_as_ushort(__float2half(0));
+#endif
            res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
         }
         i += width;
         k += 4;
-    }
+}
     for (int m = 0; m < b_end; m++) {
         atomicAdd(&mul[(b + m) * width + w], res[m]);
     }
@@ -557,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -629,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -784,7 +798,8 @@ void shuffle_exllama_weight
         gridDim.x = DIVIDE(width, THREADS_X);
         gridDim.y = height / 8;

-        make_sequential_kernel<<<gridDim, blockDim>>>
+        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+        make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            q_weight,
            new_qweight,
@@ -803,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }

 }  // namespace gptq
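Besides the ROCm fixes, every launch in this file now passes `at::cuda::getCurrentCUDAStream()` as the fourth launch parameter instead of implicitly using the default stream, so the GPTQ kernels stay ordered on whatever stream the caller is running. A small hedged sketch of the caller-side expectation, seen from Python:

    # Sketch only: work issued inside `torch.cuda.stream(s)` should stay on s.
    # Kernels hard-coded to the default stream ignore this and can run out of
    # order with respect to the surrounding ops.
    import torch

    if torch.cuda.is_available():
        s = torch.cuda.Stream()
        x = torch.randn(1024, device="cuda")
        with torch.cuda.stream(s):
            y = x * x  # stand-in for a quantized matmul launched by the extension
        s.synchronize()
        print(y.sum().item())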

+ 4 - 2
kernels/quantization/squeezellm/quant_cuda_kernel.cu

@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>

 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16
@@ -199,8 +200,9 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
-
-  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else