|
@@ -1,5 +1,5 @@
|
|
|
from abc import abstractmethod
|
|
|
-from typing import Dict, Optional
|
|
|
+from typing import Dict, Optional, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.jit
|
|
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
|
|
|
self.num_emitted_tokens: Optional[torch.Tensor] = None
|
|
|
self.num_draft_tokens: int = 0
|
|
|
|
|
|
- def init_gpu_tensors(self, rank: int) -> None:
|
|
|
+ def init_gpu_tensors(self, device: Union[int, str]) -> None:
|
|
|
assert self.num_accepted_tokens is None
|
|
|
- device = f"cuda:{rank}"
|
|
|
+ if isinstance(device, int):
|
|
|
+ device = f"cuda:{device}"
|
|
|
+ elif not isinstance(device, str):
|
|
|
+ raise ValueError(f"Device must be int or str, get {type(device)}")
|
|
|
self.num_accepted_tokens = torch.tensor(0,
|
|
|
dtype=torch.long,
|
|
|
device=device)
|