feat: allow further support for non-cuda devices (#247)

* feat: allow further support for non-cuda devices

* formatting
AlpinDale authored 1 year ago
commit ea0f57b233

aphrodite/common/config.py (+6, -0)

@@ -444,6 +444,12 @@ class SchedulerConfig:
                 f"({self.max_num_seqs}).")
 
 
+class DeviceConfig:
+
+    def __init__(self, device: str = "cuda") -> None:
+        self.device = torch.device(device)
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int

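The new DeviceConfig is a thin wrapper that resolves a device string into a torch.device once, so downstream code can carry the resolved device instead of hard-coding "cuda". A minimal usage sketch (illustrative only, not part of the commit):

    import torch
    from aphrodite.common.config import DeviceConfig

    device_config = DeviceConfig()       # default is still "cuda"
    cpu_config = DeviceConfig("cpu")     # any string torch.device accepts
    x = torch.zeros(4, device=cpu_config.device)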
aphrodite/common/utils.py (+2, -1)

@@ -224,7 +224,8 @@ def create_kv_caches_with_random(
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
 
     if isinstance(cache_dtype, str):
         if cache_dtype == "auto":

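The CUDA generator is now seeded only when a CUDA runtime is actually available, so the helper skips the CUDA-specific call on CPU-only hosts. The same guard generalizes to any CUDA-only call; a small sketch of the pattern:

    import torch

    def seed_everything(seed: int) -> None:
        # Seed the global CPU RNG unconditionally; touch the CUDA
        # generator only if a CUDA runtime is present.
        torch.random.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)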
aphrodite/engine/aphrodite_engine.py (+9, -1)

@@ -7,7 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
 
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig)
+                                     SchedulerConfig, LoRAConfig, DeviceConfig)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import StatLogger, Stats
@@ -56,6 +56,7 @@ class AphroditeEngine:
             management.
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
+        device_config: The configuration related to the device.
         lora_config: The configuration related to LoRA.
         placement_group: Ray placement group for distributed execution.
             Required for distributed execution.
@@ -68,6 +69,7 @@ class AphroditeEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
@@ -90,6 +92,7 @@ class AphroditeEngine:
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
+            f"Device = {device_config.device}\n"
             f"Seed = {model_config.seed}")
         # TODO: Print more configs in debug mode.
 
@@ -98,6 +101,7 @@ class AphroditeEngine:
         self.lora_config = lora_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
         self.log_stats = log_stats
         self._verify_args()
 
@@ -144,6 +148,7 @@ class AphroditeEngine:
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
+            self.device_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
@@ -239,6 +244,7 @@ class AphroditeEngine:
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        device_config = copy.deepcopy(self.device_config)
 
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
@@ -250,6 +256,7 @@ class AphroditeEngine:
                     model_config,
                     parallel_config,
                     scheduler_config,
+                    device_config,
                     local_rank,
                     rank,
                     distributed_init_method,
@@ -263,6 +270,7 @@ class AphroditeEngine:
             model_config,
             parallel_config,
             scheduler_config,
+            device_config,
             driver_local_rank,
             driver_rank,
             distributed_init_method,

aphrodite/engine/args_tools.py (+15, -6)

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig)
+                                     SchedulerConfig, LoRAConfig, DeviceConfig)
 
 
 @dataclass
@@ -43,6 +43,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
+    device: str = 'cuda'
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -127,13 +128,13 @@ class EngineArgs:
             '--kv-cache-dtype',
             type=str,
             choices=['auto', 'fp8_e5m2'],
-            default='auto',
+            default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
             'data type. Note FP8 is not supported when cuda version is '
             'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
-                            default=None,
+                            default=EngineArgs.max_model_len,
                             help='model context length. If unspecified, '
                             'will be automatically derived from the model.')
         # Parallel arguments
@@ -154,6 +155,7 @@ class EngineArgs:
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
+            default=EngineArgs.max_parallel_loading_workers,
             help='load model sequentially in multiple batches, '
             'to avoid RAM OOM when using tensor '
             'parallel and large models')
@@ -202,7 +204,7 @@ class EngineArgs:
             '-q',
             type=str,
             choices=['awq', 'gguf', 'gptq', 'quip', 'squeezellm', None],
-            default=None,
+            default=EngineArgs.quantization,
             help='Method used to quantize the weights. If '
             'None, we first check the `quantization_config` '
             'attribute in the model config file. If that is '
@@ -257,6 +259,12 @@ class EngineArgs:
             help=('Maximum number of LoRAs to store in CPU memory. '
                   'Must be >= than max_num_seqs. '
                   'Defaults to max_num_seqs.'))
+        parser.add_argument('--device',
+                            type=str,
+                            default=EngineArgs.device,
+                            choices=['cuda'],
+                            help=('Device to use for model execution. '
+                                  'Currently, only "cuda" is supported.'))
         return parser
 
     @classmethod
@@ -270,7 +278,8 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig]]:
+        device_config = DeviceConfig(self.device)
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
@@ -299,7 +308,7 @@ class EngineArgs:
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
         return (model_config, cache_config, parallel_config, scheduler_config,
-                lora_config)
+                device_config, lora_config)
 
 
 @dataclass

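create_engine_configs now returns the DeviceConfig alongside the existing configs, fed by the new --device flag (restricted to "cuda" for now). A rough sketch of unpacking the extended tuple; the model name is only a placeholder, and loading it needs the usual Hugging Face access:

    from aphrodite.engine.args_tools import EngineArgs

    engine_args = EngineArgs(model="facebook/opt-125m", device="cuda")
    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, lora_config) = engine_args.create_engine_configs()
    print(device_config.device)  # -> cuda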
aphrodite/modeling/layers/activation.py (+1, -3)

@@ -89,9 +89,7 @@ class ScaledActivation(nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.scales = nn.Parameter(
-            torch.empty(intermediate_size_per_partition,
-                        dtype=params_dtype,
-                        device="cuda"))
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
         set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

aphrodite/modeling/layers/attention.py (+1, -1)

@@ -200,7 +200,7 @@ def _make_alibi_bias(
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
+    bias = torch.arange(seq_len, dtype=dtype)
     # NOTE: HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but

aphrodite/modeling/layers/linear.py (+2, -8)

@@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
                        params_dtype: torch.dtype) -> Dict[str, Any]:
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
-                                       device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
@@ -118,9 +117,7 @@ class ReplicatedLinear(torch.nn.Module):
                 self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=self.params_dtype))
+                torch.empty(self.output_size, dtype=self.params_dtype))
             set_weight_attrs(self.bias, {"output_dim": 0})
         else:
             self.register_parameter("bias", None)
@@ -188,7 +185,6 @@ class ColumnParallelLinear(torch.nn.Module):
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
-                            device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
@@ -521,9 +517,7 @@ class RowParallelLinear(torch.nn.Module):
 
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=params_dtype))
+                torch.empty(self.output_size, dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,

aphrodite/modeling/layers/quantization/awq.py (+0, -3)

@@ -102,7 +102,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -118,7 +117,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -134,7 +132,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,

aphrodite/modeling/layers/quantization/gptq.py (+0, -4)

@@ -135,7 +135,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -153,7 +152,6 @@ class GPTQLinearMethod(LinearMethodBase):
                     i // self.quant_config.group_size
                     for i in range(input_size_per_partition)
                 ],
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -164,7 +162,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -180,7 +177,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,

aphrodite/modeling/layers/quantization/squeezellm.py (+2, -4)

@@ -86,7 +86,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -102,7 +101,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 output_size,
                 self.quant_config.weight_bits**2,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,
@@ -124,12 +122,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
-            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            out_f = torch.zeros(out_shape, dtype=torch.float)
             ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
             out = out_f.to(dtype=torch.float16)
         else:
             # NOTE: The output tensor should be zero-initialized.
-            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+            out = torch.zeros(out_shape, dtype=torch.float16)
             ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
 
         if bias is not None:

aphrodite/modeling/layers/rotary_embedding.py (+11, -13)

@@ -80,16 +80,13 @@ class RotaryEmbedding(nn.Module):
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                 self.rotary_dim))
+            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
         return inv_freq
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         """Compute the cos and sin cache."""
         inv_freq = self._compute_inv_freq(self.base)
-        t = torch.arange(self.max_position_embeddings,
-                         dtype=torch.float,
-                         device="cuda")
+        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
 
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
@@ -177,7 +174,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         # Thus, the maximum length after applying the rope scaling is
         # self.max_position_embeddings * self.scaling_factor.
         max_len = self.max_position_embeddings * self.scaling_factor
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)
         t = t / self.scaling_factor
 
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
@@ -217,7 +214,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
             (self.scaling_factor - 1))**(self.rotary_dim /
                                          (self.rotary_dim - 2))
         inv_freq = self._compute_inv_freq(base)
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)
 
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
@@ -300,9 +297,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                          is_neox_style)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
 
@@ -310,9 +307,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                                                 self.rotary_dim, self.base,
                                                 self.max_position_embeddings)
         # Get n-d rotational scaling corrected for extrapolation
+        # FIXME: Add device here.
+        # pylint: disable=no-value-for-parameter
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
-            low, high, self.rotary_dim // 2, dtype=torch.float,
-            device="cuda")) * self.extrapolation_factor
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
         inv_freq = inv_freq_interpolation * (
             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
         return inv_freq
@@ -320,7 +319,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)

aphrodite/modeling/layers/vocab_parallel_embedding.py (+0, -1)

@@ -152,7 +152,6 @@ class ParallelLMHead(VocabParallelEmbedding):
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
-                            device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,

aphrodite/modeling/loader.py (+3, -2)

@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from aphrodite.common.config import ModelConfig, LoRAConfig
+from aphrodite.common.config import DeviceConfig, ModelConfig, LoRAConfig
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.modeling.hf_downloader import (get_quant_config,
                                               initialize_dummy_weights)
@@ -33,6 +33,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 
 def get_model(model_config: ModelConfig,
+              device_config: DeviceConfig,
               lora_config: Optional[LoRAConfig] = None) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
 
@@ -59,7 +60,7 @@ def get_model(model_config: ModelConfig,
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        with torch.device("cuda"):
+        with torch.device(device_config.device):
             if getattr(model_class, "supports_lora", False):
                 model = model_class(model_config.hf_config, linear_method,
                                     lora_config)

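The device="cuda" removals in the layers above rely on this torch.device context (PyTorch 2.0+): get_model instantiates the model inside it, and tensors created without an explicit device inherit the ambient one. A standalone sketch of the mechanism, with a toy module in place of the real model class:

    import torch
    import torch.nn as nn

    target = torch.device("cpu")  # stand-in for device_config.device
    with torch.device(target):
        # Parameters created here land on `target`, just like the
        # now device-agnostic torch.empty(...) calls in the layers above.
        toy_model = nn.Linear(16, 16)
    print(next(toy_model.parameters()).device)  # -> cpu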
aphrodite/task_handler/cache_engine.py (+2, -0)

@@ -104,11 +104,13 @@ class CacheEngine:
                 size=(self.num_cpu_blocks, *key_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             value_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *value_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache

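The host-side KV cache blocks now name device="cpu" explicitly rather than relying on the default device, so they stay on the host even if a different default is in effect. A minimal sketch of allocating one such block; the shape is made up for illustration:

    import torch

    pin_memory = torch.cuda.is_available()  # pinning only pays off when a GPU exists
    key_block = torch.empty(
        (16, 8, 128),          # (num_blocks, num_heads, head_size) -- illustrative
        dtype=torch.float16,
        pin_memory=pin_memory,
        device="cpu",
    )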
aphrodite/task_handler/model_runner.py (+39, -25)

@@ -5,7 +5,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from aphrodite.common.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from aphrodite.common.config import (DeviceConfig, ModelConfig, LoRAConfig,
+                                     ParallelConfig, SchedulerConfig)
 from aphrodite.common.logger import init_logger
 from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
 from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
@@ -37,6 +38,7 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
@@ -51,7 +53,9 @@ class ModelRunner:
         # FIXME: This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
-        self.device = torch.device(torch.cuda.current_device())
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
@@ -74,7 +78,8 @@ class ModelRunner:
         self.kv_cache_dtype = kv_cache_dtype
 
     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.lora_config)
+        self.model = get_model(self.model_config, self.device_config,
+                               self.lora_config)
 
         vocab_size = self.model.config.vocab_size
 
@@ -184,22 +189,25 @@ class ModelRunner:
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_prompt_len,
                                                 pad=0,
-                                                dtype=torch.long)
+                                                dtype=torch.long,
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         lora_index_mapping = [
             _pad_to_max(mapping, max_prompt_len, pad=0)
             for mapping in lora_index_mapping
         ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
-                                           device="cuda")
+                                           device=self.device)
         # Prepare prefix block tables
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
         block_tables = _make_tensor_with_pad(
@@ -207,15 +215,16 @@ class ModelRunner:
             max_len=max_prompt_block_table_len,
             pad=0,
             dtype=torch.int,
+            device=self.device,
         )
         start_loc_tensor = torch.arange(0,
                                         len(prompt_lens) * max_prompt_len,
                                         max_prompt_len,
                                         dtype=torch.long,
-                                        device="cuda")
+                                        device=self.device)
         prompt_lens_tensor = torch.tensor(prompt_lens,
                                           dtype=torch.long,
-                                          device="cuda")
+                                          device=self.device)
 
         input_metadata = InputMetadata(
             is_prompt=True,
@@ -307,20 +316,20 @@ class ModelRunner:
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device="cuda")
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device="cuda")
+                                    device=self.device)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -329,7 +338,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device="cuda")
+            block_tables = torch.tensor(input_block_tables, device=self.device)
         else:
             max_block_table_len = max(
                 len(block_table) for block_table in block_tables)
@@ -338,7 +347,7 @@ class ModelRunner:
                 max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
-                device="cuda",
+                device=self.device,
             )
 
         lora_index_mapping = [
@@ -413,9 +422,13 @@ class ModelRunner:
 
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
+                                            target_device=self.device,
                                             pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
+            t: _async_h2d(seq_ids,
+                          dtype=torch.int,
+                          target_device=self.device,
+                          pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -801,14 +814,10 @@ def _make_tensor_with_pad(
     max_len: int,
     pad: int,
     dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
+    device: Optional[Union[str, torch.device]],
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x,
-                        dtype=dtype,
-                        device=device,
-                        pin_memory=pin_memory and str(device) == "cpu")
+    return torch.tensor(padded_x, dtype=dtype, device=device)
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
@@ -820,6 +829,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return (batch_size + 7) // 8 * 8
 
 
-def _async_h2d(data: list, dtype, pin_memory):
-    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
-    return t.to(device="cuda", non_blocking=True)
+def _async_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)

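_make_tensor_with_pad now takes the target device from the caller, and _async_h2d stages data on the CPU (optionally pinned) before a non-blocking copy to that device. A hedged sketch of the same host-to-device staging pattern outside the runner:

    import torch

    def async_h2d(data, dtype, target_device, pin_memory):
        # Build the tensor on the host first; pinned memory lets the
        # subsequent copy overlap with other work on a CUDA device.
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
        return t.to(device=target_device, non_blocking=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    indices = async_h2d([0, 3, 7], torch.long, device,
                        pin_memory=torch.cuda.is_available())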
aphrodite/task_handler/worker.py (+22, -15)

@@ -7,7 +7,7 @@ import torch
 import torch.distributed
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig)
+                                     SchedulerConfig, LoRAConfig, DeviceConfig)
 from aphrodite.modeling import set_random_seed
 from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
                                                           )
@@ -33,6 +33,7 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
@@ -43,6 +44,7 @@ class Worker:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
@@ -54,6 +56,7 @@ class Worker:
         self.model_runner = ModelRunner(model_config,
                                         parallel_config,
                                         scheduler_config,
+                                        device_config,
                                         lora_config=self.lora_config,
                                         kv_cache_dtype=kv_cache_dtype,
                                         is_driver_worker=is_driver_worker)
@@ -65,20 +68,24 @@ class Worker:
         self.gpu_cache = None
 
     def init_model(self) -> None:
-        # torch.distributed.all_reduce does not free the input tensor until
-        # the synchronization point. This causes the memory usage to grow
-        # as the number of all_reduce calls increases. This env var disables
-        # this behavior.
-        # Related issue:
-        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-        # This env var set by Ray causes exceptions with graph building.
-        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        self.device = torch.device(f"cuda:{self.local_rank}")
-        torch.cuda.set_device(self.device)
-
-        _check_if_gpu_supports_dtype(self.model_config.dtype)
+        if self.device_config.device.type == "cuda":
+            # torch.distributed.all_reduce does not free the input tensor until
+            # the synchronization point. This causes the memory usage to grow
+            # as the number of all_reduce calls increases. This env var disables
+            # this behavior.
+            # Related issue:
+            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+            # This env var set by Ray causes exceptions with graph building.
+            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            torch.cuda.set_device(self.device)
+
+            _check_if_gpu_supports_dtype(self.model_config.dtype)
+        else:
+            raise RuntimeError(
+                f"Device type not supported: {self.device_config.device}")
 
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,