|
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
|
|
|
nonzero_proposal_len_indices,
|
|
|
)
|
|
|
|
|
|
- def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
|
|
|
+ @staticmethod
|
|
|
+ def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
|
|
|
nonzero_proposal_len_indices, transposed):
|
|
|
"""Remove sequences from nonzero_proposal_len_indices and reset
|
|
|
their proposal_len to 0 if the draft worker does not provide a proposal
|
|
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
|
|
|
self,
|
|
|
batch_size: int,
|
|
|
proposal_len: int,
|
|
|
- maybe_sampler_output: Optional[SamplerOutput],
|
|
|
+ maybe_sampler_output: Optional[List[SamplerOutput]],
|
|
|
proposal_lens: List[int],
|
|
|
nonzero_proposal_len_indices: List[int],
|
|
|
sampler_transposed: bool,
|
|
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
|
|
|
if maybe_sampler_output is None:
|
|
|
# If no speculative tokens, the sampler output will be None.
|
|
|
# In this case we return empty proposals.
|
|
|
- proposal_tokens = torch.full(
|
|
|
- size=(
|
|
|
- batch_size,
|
|
|
- proposal_len,
|
|
|
- ),
|
|
|
- fill_value=-1,
|
|
|
- dtype=torch.long,
|
|
|
- device=self._device,
|
|
|
- )
|
|
|
- proposal_probs = torch.zeros(
|
|
|
- batch_size,
|
|
|
- proposal_len,
|
|
|
- self._vocab_size,
|
|
|
- dtype=torch.float32,
|
|
|
- device=self._device,
|
|
|
- )
|
|
|
- proposal_lens_tensor = torch.zeros(len(proposal_lens),
|
|
|
- dtype=torch.long,
|
|
|
- device=self._device)
|
|
|
+ proposal_tokens = torch.tensor(-1,
|
|
|
+ dtype=torch.long,
|
|
|
+ device=self._device).expand(
|
|
|
+ batch_size, proposal_len)
|
|
|
+ proposal_probs = torch.tensor(0,
|
|
|
+ dtype=torch.float32,
|
|
|
+ device=self._device).expand(
|
|
|
+ batch_size, proposal_len,
|
|
|
+ self._vocab_size)
|
|
|
+ proposal_lens_tensor = torch.tensor(0,
|
|
|
+ dtype=torch.long,
|
|
|
+ device=self._device).expand(
|
|
|
+ len(proposal_lens))
|
|
|
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
|
|
|
|
|
sampler_output = maybe_sampler_output
|
|
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
|
|
|
# Now, reformat the output GPU tensors such that each sequence has
|
|
|
# a proposal. The proposal can be empty, e.g. [-1, -1, -1]
|
|
|
|
|
|
- entire_proposal_tokens = torch.full(
|
|
|
+ entire_proposal_tokens = proposal_tokens.new_full(
|
|
|
size=(batch_size, *proposal_tokens.shape[1:]),
|
|
|
fill_value=-1,
|
|
|
- dtype=torch.long,
|
|
|
- device=self._device,
|
|
|
)
|
|
|
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
|
|
- entire_proposal_probs = torch.zeros(
|
|
|
+ entire_proposal_probs = proposal_probs.new_zeros(
|
|
|
batch_size,
|
|
|
*proposal_probs.shape[1:],
|
|
|
- dtype=torch.float32,
|
|
|
- device=self._device,
|
|
|
)
|
|
|
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
|
|
|