
chore: backlogs 1 (#191)

AlpinDale 1 year ago
Commit b9b295d74e

+ 10 - 18
aphrodite/common/config.py

@@ -95,23 +95,20 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = ["safetensors"]
+        rocm_not_supported_load_format = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
                 "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip():
-            if load_format in ["safetensors"]:
-                rocm_supported_load_format = [
-                    f for f in supported_load_format
-                    if (f not in rocm_not_supported_load_format)
-                ]
-                raise ValueError(
-                    f"load format {load_format} is not supported on ROCm. "
-                    f"Must be one of {rocm_supported_load_format}.")
-            # force ROCm to load from pt weights if nothing is set
-            if load_format == "auto":
-                load_format = "pt"
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in supported_load_format
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format \'{load_format}\' is not supported in ROCm. "
+                f"Supported load format are "
+                f"{rocm_supported_load_format}")
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
@@ -166,11 +163,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True
 
     def verify_with_parallel_config(
         self,

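The config.py hunks relax two ROCm/quantization restrictions: safetensors is no longer blacklisted (and "auto" is no longer coerced to "pt") on ROCm, and GPTQ/SqueezeLLM models are no longer forced into eager mode. A minimal sketch of the resulting load-format check, written as a standalone function purely for illustration (the names mirror the diff; everything else is assumed):

    # Sketch of the post-change validation; not the actual ModelConfig code.
    SUPPORTED_LOAD_FORMATS = ["auto", "pt", "safetensors", "npcache", "dummy"]
    ROCM_NOT_SUPPORTED = []  # safetensors is no longer excluded on ROCm

    def validate_load_format(load_format: str, on_rocm: bool) -> str:
        if load_format not in SUPPORTED_LOAD_FORMATS:
            raise ValueError(f"Unknown load format: {load_format}. "
                             f"Must be one of {SUPPORTED_LOAD_FORMATS}.")
        if on_rocm and load_format in ROCM_NOT_SUPPORTED:
            supported = [f for f in SUPPORTED_LOAD_FORMATS
                         if f not in ROCM_NOT_SUPPORTED]
            raise ValueError(f"Load format '{load_format}' is not supported "
                             f"on ROCm. Supported formats: {supported}.")
        return load_format  # note: "auto" is no longer rewritten to "pt"
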
+ 11 - 1
aphrodite/endpoints/openai/api_server.py

@@ -87,6 +87,14 @@ def parse_args():
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=True.")
+    parser.add_argument("--ssl-keyfile",
+                        type=str,
+                        default=None,
+                        help="SSL key file path.")
+    parser.add_argument("--ssl-certfile",
+                        type=str,
+                        default=None,
+                        help="SSL cert file path.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser.parse_args()
@@ -819,4 +827,6 @@ if __name__ == "__main__":
                 host=args.host,
                 port=args.port,
                 log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                ssl_keyfile=args.ssl_keyfile,
+                ssl_certfile=args.ssl_certfile)

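With the two new flags, the OpenAI-compatible server can terminate TLS directly in uvicorn, since uvicorn.run() accepts ssl_keyfile and ssl_certfile. A minimal standalone sketch of the same wiring, assuming a placeholder key/cert pair at key.pem and cert.pem and an arbitrary port:

    # Sketch: serving a FastAPI app over TLS with uvicorn, mirroring how the
    # new --ssl-keyfile/--ssl-certfile flags are passed through to uvicorn.run.
    from fastapi import FastAPI
    import uvicorn

    app = FastAPI()

    @app.get("/health")
    def health():
        return {"status": "ok"}

    if __name__ == "__main__":
        uvicorn.run(app,
                    host="0.0.0.0",
                    port=8000,                  # placeholder port
                    ssl_keyfile="key.pem",      # --ssl-keyfile
                    ssl_certfile="cert.pem")    # --ssl-certfile

The same files can then be supplied when launching aphrodite.endpoints.openai.api_server, e.g. --ssl-keyfile key.pem --ssl-certfile cert.pem.
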
+ 16 - 15
aphrodite/engine/aphrodite_engine.py

@@ -1,4 +1,5 @@
 import copy
+import os
 import time
 from functools import partial
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
@@ -15,7 +16,6 @@ from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                                       SequenceGroupMetadata,
                                        SequenceGroupOutput, SequenceOutput,
                                        SequenceStatus)
 from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
@@ -106,6 +106,10 @@ class AphroditeEngine:
 
         # Create the parallel GPU workers.
         if self.parallel_config.worker_use_ray:
+            # Disable Ray usage stats collection.
+            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+            if ray_usage != "1":
+                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
             self._init_workers_ray(placement_group)
         else:
             self._init_workers(distributed_init_method)
@@ -250,6 +254,15 @@ class AphroditeEngine:
                              "Try increasing `gpu_memory_utilization` when "
                              "initializing the engine.")
 
+        max_seq_len = self.cache_config.block_size * num_gpu_blocks
+        if self.model_config.max_model_len > max_seq_len:
+            raise ValueError(
+                f"The model's max seq len ({self.model_config.max_model_len}) "
+                "is larger than the maximum number of tokens that can be "
+                f"stored in KV cache ({max_seq_len}). Try increasing "
+                "`gpu_memory_utilization` or decreasing `max_model_len` when "
+                "initializing the engine.")
+
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
@@ -337,16 +350,6 @@ class AphroditeEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],
@@ -597,9 +600,7 @@ class AphroditeEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
         output = self._run_workers(
@@ -608,7 +609,7 @@ class AphroditeEngine:
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        ) if not scheduler_outputs.is_empty() else []
 
         return self._process_model_outputs(output, scheduler_outputs)
 

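Two behavioural changes sit in this file: Ray usage-stats collection is disabled unless RAY_USAGE_STATS_ENABLED=1 was already set, and engine startup now fails fast when the model's context window cannot fit in the profiled KV cache. The capacity bound is simply block_size * num_gpu_blocks tokens. A back-of-the-envelope example with assumed numbers:

    # Worked example of the new KV-cache capacity check (all numbers assumed).
    block_size = 16          # tokens held per KV-cache block
    num_gpu_blocks = 2048    # blocks that fit in the GPU memory budget
    max_model_len = 40960    # the model's configured context length

    max_seq_len = block_size * num_gpu_blocks   # 32768 tokens of KV cache
    if max_model_len > max_seq_len:
        raise ValueError(
            f"max_model_len ({max_model_len}) exceeds KV-cache capacity "
            f"({max_seq_len}); increase gpu_memory_utilization or decrease "
            "max_model_len.")
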
+ 4 - 6
aphrodite/engine/async_aphrodite.py

@@ -182,20 +182,18 @@ class _AsyncAphrodite(AphroditeEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
-        output = await self._run_workers_async(
+        output = (await self._run_workers_async(
             "execute_model",
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        )) if not scheduler_outputs.is_empty() else []
 
-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)
 
     async def _run_workers_async(
         self,

+ 3 - 0
aphrodite/modeling/models/gpt_neox.py

@@ -59,6 +59,7 @@ class GPTNeoXAttention(nn.Module):
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
+        self.bias = getattr(config, "attention_bias", True)
 
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
@@ -70,11 +71,13 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             self.head_size,
             self.total_num_heads,
+            bias=self.bias,
             linear_method=linear_method,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
+            bias=self.bias,
             linear_method=linear_method,
         )
 

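GPT-NeoX attention and dense projections were previously always built with biases; some newer configs in the NeoX family carry an attention_bias flag, and getattr(..., True) keeps the old behaviour whenever the field is missing. A small sketch of the same pattern against a stock Transformers config (the model name is only an example):

    # Sketch: reading an optional HF config field with a backward-compatible
    # default, as the diff does for the GPT-NeoX projection biases.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("EleutherAI/pythia-70m")
    use_bias = getattr(config, "attention_bias", True)  # default: biased layers
    print(use_bias)
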
+ 5 - 26
aphrodite/modeling/models/mixtral.py

@@ -50,7 +50,6 @@ from aphrodite.modeling.megatron.parallel_state import (
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -95,30 +94,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
@@ -148,7 +123,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -433,6 +408,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

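With DummyModule removed, experts that belong to other tensor-parallel ranks are now plain None entries, and their checkpoint tensors are skipped by name during weight loading instead of being routed to a throwaway module. A minimal sketch of that partition-and-skip idea, assuming 8 experts over 4 ranks (everything here is illustrative except the block_sparse_moe.experts. prefix used in the diff):

    # Sketch: give each tensor-parallel rank a contiguous slice of experts and
    # skip weight names that belong to experts owned by other ranks.
    num_total_experts = 8
    tp_size = 4
    rank = 1                                    # this worker's TP rank
    experts_per_rank = num_total_experts // tp_size
    expert_indices = range(rank * experts_per_rank,
                           (rank + 1) * experts_per_rank)   # experts 2 and 3

    experts = [object() if idx in expert_indices else None  # stand-in modules
               for idx in range(num_total_experts)]

    def should_skip(name: str, params_dict: dict) -> bool:
        # Mirrors the new loader rule: an unknown expert weight is not an
        # error, it simply lives on another rank.
        return ("block_sparse_moe.experts." in name
                and name not in params_dict)
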
+ 3 - 3
aphrodite/processing/block_manager.py

@@ -103,7 +103,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME: Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.get_seqs()[0]
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
         if self.block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
@@ -121,7 +121,7 @@ class BlockSpaceManager:
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
         # prompt.
-        seq = seq_group.get_seqs()[0]
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
@@ -136,7 +136,7 @@ class BlockSpaceManager:
             block_table.append(block)
 
         # Assign the block table for each sequence.
-        for seq in seq_group.get_seqs():
+        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()
 
     def can_append_slot(self, seq_group: SequenceGroup) -> bool:

+ 7 - 5
aphrodite/processing/scheduler.py

@@ -139,15 +139,17 @@ class Scheduler:
             while self.waiting:
                 seq_group = self.waiting[0]
 
-                assert seq_group.num_seqs() == 1, (
+                waiting_seqs = seq_group.get_seqs(
+                    status=SequenceStatus.WAITING)
+                assert len(waiting_seqs) == 1, (
                     "Waiting sequence group should have only one prompt "
                     "sequence.")
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                num_prompt_tokens = waiting_seqs[0].get_len()
                 if num_prompt_tokens > self.prompt_limit:
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         f" and exceeds limit of {self.prompt_limit}")
-                    for seq in seq_group.get_seqs():
+                    for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
                     self.waiting.pop(0)
@@ -161,7 +163,7 @@ class Scheduler:
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         f" and exceeds the capacity of the block manager.")
-                    for seq in seq_group.get_seqs():
+                    for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
                     self.waiting.pop(0)
@@ -320,7 +322,7 @@ class Scheduler:
 
     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.get_seqs():
+        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             seq.status = SequenceStatus.RUNNING
 
     def _append_slot(

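Both block_manager.py and scheduler.py now ask a sequence group only for its WAITING sequences rather than taking get_seqs()[0], which matters once a group can also hold sequences that were already ignored or are running. A toy sketch of the status filter with simplified stand-in classes (the real ones live in aphrodite.common.sequence):

    # Toy sketch of status-filtered get_seqs(); the classes are stand-ins.
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import List, Optional

    class SequenceStatus(Enum):
        WAITING = auto()
        RUNNING = auto()
        FINISHED_IGNORED = auto()

    @dataclass
    class Seq:
        seq_id: int
        status: SequenceStatus

    @dataclass
    class SeqGroup:
        seqs: List[Seq]

        def get_seqs(self,
                     status: Optional[SequenceStatus] = None) -> List[Seq]:
            if status is None:
                return self.seqs
            return [s for s in self.seqs if s.status == status]

    group = SeqGroup([Seq(0, SequenceStatus.FINISHED_IGNORED),
                      Seq(1, SequenceStatus.WAITING)])
    # get_seqs()[0] would pick the ignored seq 0; filtering picks seq 1.
    prompt_seq = group.get_seqs(status=SequenceStatus.WAITING)[0]
    assert prompt_seq.seq_id == 1
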
+ 3 - 0
kernels/activation_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include "cuda_compat.h"
@@ -36,6 +37,7 @@ void silu_and_mul(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d;                                                 \
   dim3 grid(num_tokens);                                                                  \
   dim3 block(std::min(d, 1024));                                                          \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
   APHRODITE_DISPATCH_FLOATING_TYPES(                                                           \
     input.scalar_type(),                                                                  \

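This file and the remaining kernel files each gain an at::cuda::OptionalCUDAGuard on the input's device, so the device that is current when getCurrentCUDAStream() is queried matches the tensor being operated on. Seen from Python, the failure mode it prevents looks roughly like the sketch below, which assumes at least two GPUs and that the activation op is importable as shown (the binding name is an assumption):

    # Sketch: why the device guard matters for multi-GPU callers.
    import torch
    from aphrodite import activation_ops  # assumed binding; adjust to the build

    x = torch.randn(4, 8, device="cuda:1", dtype=torch.float16)
    out = torch.empty(4, 4, device="cuda:1", dtype=torch.float16)
    # The current device is still cuda:0 here. The OptionalCUDAGuard inside
    # the extension switches to device_of(x) before the stream is looked up,
    # so the kernel runs on cuda:1 rather than on cuda:0's current stream.
    activation_ops.silu_and_mul(out, x)
    torch.cuda.synchronize(device=1)
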
+ 3 - 0
kernels/attention/attention_kernels.cu

@@ -21,6 +21,7 @@
 #endif
 
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include "attention_dtypes.h"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE: To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE: To reduce the compilation time, we only compile for the

+ 5 - 0
kernels/cache_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include "cuda_compat.h"
@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE: This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),

+ 3 - 0
kernels/layernorm_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include "dispatch_utils.h"
@@ -75,6 +76,7 @@ void rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -100,6 +102,7 @@ void fused_add_rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

+ 2 - 0
kernels/pos_encoding_kernels.cu

@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include "cuda_compat.h"
@@ -94,6 +95,7 @@ void rotary_embedding(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),

+ 1 - 1
kernels/quantization/gptq/matrix_view.cuh

@@ -147,5 +147,5 @@ public:
 };
 
 }  // namespace gptq
-}  // namespace vllm
+}  // namespace aphrodite
 #endif

+ 23 - 7
kernels/quantization/gptq/q_gemm.cu

@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 
 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -286,7 +287,8 @@ void gemm_half_q_half_cuda_part
 
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
 
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -433,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
 
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -520,16 +523,25 @@ __global__ void gemm_half_q_half_alt_kernel(
             zeros_tmp[tmp_k] = zero;
         }
         for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
             res2 = {};
+#else
+            res2.x = __half_as_ushort(__float2half(0));
+            res2.y = __half_as_ushort(__float2half(0));
+#endif
             res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
             res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
         }
         i += width;
         k += 4;
-    }
+}
     for (int m = 0; m < b_end; m++) {
         atomicAdd(&mul[(b + m) * width + w], res[m]);
     }
@@ -557,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
 
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -629,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -784,7 +798,8 @@ void shuffle_exllama_weight
         gridDim.x = DIVIDE(width, THREADS_X);
         gridDim.y = height / 8;
 
-        make_sequential_kernel<<<gridDim, blockDim>>>
+        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+        make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             q_weight,
             new_qweight,
@@ -803,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 
 }  // namespace gptq

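Every GPTQ kernel launch above now goes out on at::cuda::getCurrentCUDAStream() instead of the legacy default stream, which is what allows these kernels to be recorded into a CUDA graph (and is presumably why the config.py hunk could drop the forced enforce_eager for GPTQ/SqueezeLLM). A rough sketch of the capture pattern this enables, with a plain nn.Linear standing in for the quantized matmul:

    # Sketch: CUDA graph capture only records work submitted on the capture
    # stream, so kernels hard-wired to the default stream would escape the
    # graph or break capture. An fp16 linear layer stands in for GPTQ GEMM.
    import torch

    model = torch.nn.Linear(128, 128).half().cuda()
    static_in = torch.randn(8, 128, device="cuda", dtype=torch.float16)

    # Warm up on a side stream, as the PyTorch capture docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    static_in.copy_(torch.randn_like(static_in))
    g.replay()   # re-runs the captured kernels on the new input contents
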
+ 4 - 2
kernels/quantization/squeezellm/quant_cuda_kernel.cu

@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>
 
 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16
@@ -199,8 +200,9 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
-
-  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else