
chore: minor simplifications

AlpinDale, 7 months ago
parent
commit
4d1e613804

+ 1 - 1
aphrodite/engine/output_processor/multi_step.py

@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 
         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]
 
         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
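The change above swaps index-based iteration for direct iteration over the outputs list; the two comprehensions produce identical results. A minimal sketch, using a hypothetical stand-in for SamplerOutput:

    # FakeOutput is a hypothetical stand-in; each .samples list holds one
    # entry because there is a single sequence per sequence group.
    class FakeOutput:
        def __init__(self, samples):
            self.samples = samples

    outputs = [FakeOutput([f"sample-{i}"]) for i in range(3)]

    before = [outputs[step].samples[0] for step in range(len(outputs))]
    after = [output.samples[0] for output in outputs]
    assert before == after  # same elements, same order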

+ 4 - 2
aphrodite/modeling/layers/rejection.py

@@ -301,8 +301,10 @@ class RejectionSampler(nn.Module):
 
         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)
 
         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
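Passing out= lets torch.where write its result directly into an existing tensor instead of allocating a temporary and copying it over with slice assignment. A minimal sketch, assuming (as the surrounding comment suggests) that output is a view over the first k columns of a larger tensor:

    import torch

    bs, k = 2, 3
    output_with_bonus = torch.empty(bs, k + 1, dtype=torch.long)
    output = output_with_bonus[:, :k]  # assumed view over the first k columns

    accepted_mask = torch.tensor([[True, False, True],
                                  [False, True, True]])
    draft_token_ids = torch.arange(bs * k).reshape(bs, k)

    # Accepted draft ids are kept; rejected positions become -1. The result
    # lands in the view, so the parent tensor is updated with no extra copy.
    torch.where(accepted_mask,
                draft_token_ids,
                -torch.ones_like(draft_token_ids),
                out=output)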

+ 12 - 27
aphrodite/spec_decode/batch_expansion.py

@@ -81,7 +81,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
 
         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
@@ -141,8 +141,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 num_scoring_tokens)
 
     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -168,30 +167,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
-
-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)
+
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))
 
         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
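Tensor.new_full and Tensor.new_zeros create tensors that inherit dtype and device from the tensor they are called on, which is what lets the rewrite above drop the repeated device=self._device and dtype=... arguments. A minimal sketch:

    import torch

    ref = torch.zeros(4, 5, dtype=torch.float32)  # stands in for target_probs

    # Same dtype/device as ref, no explicit device= or dtype= needed.
    filled = ref.new_full(size=(2, 3), fill_value=-float("inf"))
    zeros = ref.new_zeros(2, 3, 7)

    assert filled.dtype == ref.dtype and filled.device == ref.device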

+ 7 - 7
aphrodite/spec_decode/spec_decode_worker.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from loguru import logger
 
+from aphrodite.common.config import SpeculativeConfig
 from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                        SequenceGroupMetadata)
 from aphrodite.distributed.communication_op import broadcast_tensor_dict
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None
 
     target_worker = Worker(*args, **kwargs)
@@ -105,12 +106,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("Configuring SpecDecodeWorker with "
                     f"proposer={type(proposer_worker)}")
 
-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))
 
     def __init__(
         self,
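The only substantive change here is the annotation on the kwargs.get result: without it, type checkers see the value as plain Any, while with it the later attribute accesses are checked against SpeculativeConfig. A minimal sketch with a hypothetical Config class:

    from typing import Any, Optional

    class Config:  # hypothetical stand-in for SpeculativeConfig
        num_speculative_tokens: int = 5

    def create_worker(**kwargs: Any) -> None:
        # The annotation narrows the type for checkers and IDE completion.
        config: Optional[Config] = kwargs.get("speculative_config")
        assert config is not None
        print(config.num_speculative_tokens)

    create_worker(speculative_config=Config())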

+ 18 - 27
aphrodite/spec_decode/top1_proposer.py

@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
             nonzero_proposal_len_indices,
         )
 
-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor
 
         sampler_output = maybe_sampler_output
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
 
-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
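torch.Tensor.expand returns a broadcast view of the scalar, so the empty-proposal tensors above cost O(1) memory rather than batch_size * proposal_len * vocab_size elements. The trade-off: every element aliases the same storage location, so expanded tensors must stay read-only (clone() first if an in-place write is ever needed). A minimal sketch:

    import torch

    batch_size, proposal_len, vocab_size = 4, 3, 10

    proposal_tokens = torch.tensor(-1, dtype=torch.long).expand(
        batch_size, proposal_len)
    proposal_probs = torch.tensor(0, dtype=torch.float32).expand(
        batch_size, proposal_len, vocab_size)

    assert proposal_tokens.shape == (batch_size, proposal_len)
    # Stride 0 in every dim: all elements share one storage location.
    assert proposal_tokens.stride() == (0, 0)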
 

+ 3 - 8
aphrodite/spec_decode/util.py

@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
 
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        SamplerOutput, SequenceGroupMetadata,
-                                       SequenceGroupOutput, SequenceOutput)
+                                       SequenceOutput)
 
 SeqId = int
 
@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
 
 
 def get_all_num_logprobs(
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.
 
     Args:
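The nested comprehension above flattens the per-group id lists without itertools.chain; iterating a dict yields its keys, so `for seq_id in sg.seq_data` matches the old `.keys()` access. A minimal sketch with a hypothetical stand-in for SequenceGroupMetadata:

    from dataclasses import dataclass
    from typing import Dict

    @dataclass
    class Meta:  # hypothetical stand-in for SequenceGroupMetadata
        seq_data: Dict[int, object]

    groups = [Meta({0: None, 1: None}), Meta({7: None})]

    # Iterating a dict yields its keys, preserving insertion order.
    seq_ids = [seq_id for sg in groups for seq_id in sg.seq_data]
    assert seq_ids == [0, 1, 7]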

+ 15 - 16
aphrodite/task_handler/model_runner.py

@@ -547,17 +547,6 @@ class ModelRunner:
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
 
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
@@ -565,11 +554,6 @@ class ModelRunner:
                                     dtype=torch.int32,
                                     device=self.device)
 
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
-
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -622,6 +606,21 @@ class ModelRunner:
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
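query_start_loc follows the usual exclusive-prefix-sum layout: one extra leading zero slot, with the cumulative query lengths written into positions 1..N via out= on a slice, so entry i is the start offset of sequence i in the flattened token batch. The commit simply moves that construction into the only branch that uses it. A minimal sketch:

    import torch

    query_lens = [3, 1, 4]
    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

    # N + 1 slots; index 0 stays 0, indices 1..N get the running totals.
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    assert query_start_loc.tolist() == [0, 3, 4, 8]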

+ 2 - 2
aphrodite/transformers_utils/config.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Optional
+from typing import Dict, Optional, Type
 
 from loguru import logger
 from transformers import PretrainedConfig
@@ -8,7 +8,7 @@ from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                   JAISConfig, MPTConfig,
                                                   RWConfig)
 
-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,