|
@@ -81,7 +81,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|
|
|
|
|
target_sampler_output = self._scorer_worker.execute_model(
|
|
|
execute_model_req=execute_model_req.clone(
|
|
|
- seq_group_metadata_list=target_seq_group_metadata_list, ))
|
|
|
+ seq_group_metadata_list=target_seq_group_metadata_list))
|
|
|
assert len(target_sampler_output) == 1, "expected single-step output"
|
|
|
target_sampler_output = target_sampler_output[0]
|
|
|
|
|
@@ -141,8 +141,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|
|
num_scoring_tokens)
|
|
|
|
|
|
def _contract_batch(
|
|
|
- self, contracted_bs: int,
|
|
|
- target_sampler_output: List[SamplerOutput],
|
|
|
+ self, contracted_bs: int, target_sampler_output: SamplerOutput,
|
|
|
proposals: SpeculativeProposals, num_scoring_tokens: int,
|
|
|
non_spec_indices: List[int], spec_indices: List[int],
|
|
|
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
@@ -168,30 +167,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|
|
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
|
|
|
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
|
|
|
|
|
- target_token_ids = target_token_ids.squeeze().reshape(
|
|
|
- spec_expanded_bs, k + 1)
|
|
|
- target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
|
|
|
- self._vocab_size)
|
|
|
- target_logprobs = target_logprobs.squeeze().reshape(
|
|
|
- spec_expanded_bs, k + 1, self._vocab_size)
|
|
|
-
|
|
|
- all_tokens = torch.full(size=(contracted_bs, k + 1),
|
|
|
- fill_value=-1,
|
|
|
- device=self._device,
|
|
|
- dtype=torch.long)
|
|
|
- all_probs = torch.zeros(contracted_bs,
|
|
|
- k + 1,
|
|
|
- self._vocab_size,
|
|
|
- device=self._device,
|
|
|
- dtype=torch.float32)
|
|
|
- all_logprobs = torch.full(size=(
|
|
|
- contracted_bs,
|
|
|
- k + 1,
|
|
|
- self._vocab_size,
|
|
|
- ),
|
|
|
- fill_value=-float("inf"),
|
|
|
- device=self._device,
|
|
|
- dtype=torch.float32)
|
|
|
+ target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
|
|
+ target_probs = target_probs.reshape(*target_token_ids.shape,
|
|
|
+ self._vocab_size)
|
|
|
+ target_logprobs = target_logprobs.reshape(target_probs.shape)
|
|
|
+
|
|
|
+ all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
|
|
|
+ fill_value=-1)
|
|
|
+ all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
|
|
|
+ all_logprobs = target_logprobs.new_full(size=all_probs.shape,
|
|
|
+ fill_value=-float("inf"))
|
|
|
|
|
|
if non_spec_indices:
|
|
|
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
|