Przeglądaj źródła

feat: Add support for GPU device selection in SpecDecodeBaseSampler (#629)

AlpinDale 6 miesięcy temu
rodzic
commit
09b82f9963

+ 6 - 3
aphrodite/modeling/layers/spec_decode_base_sampler.py

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.jit
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
 
-    def init_gpu_tensors(self, rank: int) -> None:
+    def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
-        device = f"cuda:{rank}"
+        if isinstance(device, int):
+            device = f"cuda:{device}"
+        elif not isinstance(device, str):
+            raise ValueError(f"Device must be int or str, get {type(device)}")
         self.num_accepted_tokens = torch.tensor(0,
                                                 dtype=torch.long,
                                                 device=device)