
fix: loading chameleon model with TP>1 (#695)

AlpinDale 6 months ago
commit 0e558e9b2f
55 changed files with 313 additions and 134 deletions
  1. +5 -3   aphrodite/common/outputs.py
  2. +1 -1   aphrodite/distributed/communication_op.py
  3. +1 -1   aphrodite/distributed/parallel_state.py
  4. +8 -4   aphrodite/modeling/layers/logits_processor.py
  5. +5 -2   aphrodite/modeling/model_loader/neuron.py
  6. +5 -2   aphrodite/modeling/model_loader/openvino.py
  7. +24 -2  aphrodite/modeling/model_loader/weight_utils.py
  8. +5 -2   aphrodite/modeling/models/arctic.py
  9. +5 -2   aphrodite/modeling/models/baichuan.py
 10. +5 -2   aphrodite/modeling/models/bart.py
 11. +5 -2   aphrodite/modeling/models/blip2.py
 12. +5 -2   aphrodite/modeling/models/bloom.py
 13. +18 -5  aphrodite/modeling/models/chameleon.py
 14. +5 -2   aphrodite/modeling/models/chatglm.py
 15. +10 -19 aphrodite/modeling/models/commandr.py
 16. +5 -2   aphrodite/modeling/models/dbrx.py
 17. +5 -2   aphrodite/modeling/models/deepseek.py
 18. +5 -2   aphrodite/modeling/models/deepseek_v2.py
 19. +5 -2   aphrodite/modeling/models/falcon.py
 20. +5 -2   aphrodite/modeling/models/fuyu.py
 21. +5 -2   aphrodite/modeling/models/gemma.py
 22. +5 -2   aphrodite/modeling/models/gemma2.py
 23. +5 -2   aphrodite/modeling/models/gpt2.py
 24. +5 -2   aphrodite/modeling/models/gpt_bigcode.py
 25. +5 -2   aphrodite/modeling/models/gpt_j.py
 26. +5 -2   aphrodite/modeling/models/gpt_neox.py
 27. +5 -2   aphrodite/modeling/models/internlm2.py
 28. +5 -2   aphrodite/modeling/models/internvl.py
 29. +5 -2   aphrodite/modeling/models/jais.py
 30. +5 -2   aphrodite/modeling/models/jamba.py
 31. +5 -2   aphrodite/modeling/models/llama.py
 32. +5 -2   aphrodite/modeling/models/llava.py
 33. +5 -2   aphrodite/modeling/models/llava_next.py
 34. +5 -2   aphrodite/modeling/models/mamba.py
 35. +11 -5  aphrodite/modeling/models/medusa.py
 36. +5 -2   aphrodite/modeling/models/minicpm.py
 37. +5 -2   aphrodite/modeling/models/minicpmv.py
 38. +5 -2   aphrodite/modeling/models/mixtral.py
 39. +5 -2   aphrodite/modeling/models/mixtral_quant.py
 40. +5 -2   aphrodite/modeling/models/mpt.py
 41. +5 -2   aphrodite/modeling/models/nemotron.py
 42. +5 -2   aphrodite/modeling/models/olmo.py
 43. +5 -2   aphrodite/modeling/models/opt.py
 44. +5 -2   aphrodite/modeling/models/orion.py
 45. +5 -2   aphrodite/modeling/models/paligemma.py
 46. +5 -2   aphrodite/modeling/models/persimmon.py
 47. +5 -2   aphrodite/modeling/models/phi.py
 48. +5 -2   aphrodite/modeling/models/phi3_small.py
 49. +5 -2   aphrodite/modeling/models/phi3v.py
 50. +5 -2   aphrodite/modeling/models/qwen.py
 51. +5 -2   aphrodite/modeling/models/qwen2.py
 52. +5 -2   aphrodite/modeling/models/qwen2_moe.py
 53. +5 -2   aphrodite/modeling/models/stablelm.py
 54. +5 -2   aphrodite/modeling/models/starcoder2.py
 55. +5 -2   aphrodite/modeling/models/xverse.py

+ 5 - 3
aphrodite/common/outputs.py

@@ -1,6 +1,8 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
+from typing import Sequence as GenericSequence
+from typing import Union
 
 from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
                                        SampleLogprobs, SequenceGroup,
@@ -29,7 +31,7 @@ class CompletionOutput:
 
     index: int
     text: str
-    token_ids: Tuple[int, ...]
+    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
@@ -139,7 +141,7 @@ class RequestOutput:
             CompletionOutput(
                 seqs.index(seq),
                 seq.get_output_text_to_return(text_buffer_length),
-                seq.data._output_token_ids, # type: ignore
+                seq.data._output_token_ids,
                 seq.get_cumulative_logprob() if include_logprobs else None,
                 seq.output_logprobs if include_logprobs else None,
                 SequenceStatus.get_finished_reason(seq.status),

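The widened token_ids annotation above is what lets RequestOutput hand over the sequence's internal token list directly, without a tuple copy and without the # type: ignore. A minimal sketch of why typing.Sequence covers both cases; the function name here is illustrative, not part of the commit:

from typing import Sequence

def takes_token_ids(token_ids: Sequence[int]) -> int:
    # Sequence[int] is a read-only view that covers both list and tuple,
    # so no defensive tuple(...) copy is needed at the call site.
    return len(token_ids)

takes_token_ids([1, 2, 3])   # a list type-checks fine
takes_token_ids((1, 2, 3))   # and so does the old tuple form
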
+ 1 - 1
aphrodite/distributed/communication_op.py

@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
 
 def tensor_model_parallel_gather(input_: torch.Tensor,
                                  dst: int = 0,
-                                 dim: int = -1) -> torch.Tensor:
+                                 dim: int = -1) -> Optional[torch.Tensor]:
     """Gather the input tensor across model parallel group."""
     return get_tp_group().gather(input_, dst, dim)
 

+ 1 - 1
aphrodite/distributed/parallel_state.py

@@ -329,7 +329,7 @@ class GroupCoordinator:
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
-               dim: int = -1) -> torch.Tensor:
+               dim: int = -1) -> Optional[torch.Tensor]:
         """
         NOTE: We assume that the input tensor is on the same device across
         all the ranks.

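Both gather wrappers now advertise Optional[torch.Tensor] because torch.distributed.gather only materializes the gathered tensor on the destination rank. A minimal sketch of that behavior, assuming the process group is already initialized; this is illustrative, not the GroupCoordinator code:

import torch
import torch.distributed as dist

def gather_sketch(input_: torch.Tensor, dst: int = 0, dim: int = -1):
    # Only the destination rank allocates receive buffers.
    world_size = dist.get_world_size()
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if dist.get_rank() == dst else None)
    dist.gather(input_, gather_list, dst=dst)
    if dist.get_rank() != dst:
        return None  # hence Optional[torch.Tensor] on every other rank
    return torch.cat(gather_list, dim=dim)
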
+ 8 - 4
aphrodite/modeling/layers/logits_processor.py

@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor,
-                    lm_head: VocabParallelEmbedding,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
         if self.use_gather:
+            # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
         else:
             # Gather is not supported for some devices such as TPUs.

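With the gather above, only the destination rank ever sees real logits under TP>1; every other rank receives None and must not index into it, which is why every compute_logits in this commit returns Optional[torch.Tensor]. A self-contained sketch of the calling pattern this implies; the function and the greedy stand-in sampler are placeholders, not code from this commit:

from typing import Optional
import torch

def sample_if_driver(logits: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Non-driver tensor-parallel ranks receive None: the logits were
    # gathered onto rank 0, so there is nothing to sample locally.
    if logits is None:
        return None
    return logits.argmax(dim=-1)  # greedy stand-in for the real sampler

assert sample_if_driver(None) is None                      # rank > 0
assert sample_if_driver(torch.randn(2, 32)).shape == (2,)  # driver rank
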
+ 5 - 2
aphrodite/modeling/model_loader/neuron.py

@@ -62,8 +62,11 @@ class NeuronCasualLM(nn.Module):
                             start_ids=input_block_ids)
         return logits
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(None, hidden_states, sampling_metadata)
         return logits
 

+ 5 - 2
aphrodite/modeling/model_loader/openvino.py

@@ -179,8 +179,11 @@ class OpenVINOCasualLM(nn.Module):
         # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
         return logits.view(-1, logits.shape[-1])
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
         logits = self.logits_processor(None, hidden_states, sampling_metadata)
         return logits

+ 24 - 2
aphrodite/modeling/model_loader/weight_utils.py

@@ -501,8 +501,30 @@ def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
 
-    assert param.size() == loaded_weight.size()
-    param.data.copy_(loaded_weight)
+    try:
+        assert param.size() == loaded_weight.size(), (
+            f"Attempted to load weight ({loaded_weight.size()}) "
+            f"into parameter ({param.size()})")
+
+        param.data.copy_(loaded_weight)
+    except Exception:
+        # NOTE: This exception is added for the purpose of setting breakpoint to
+        # debug weight loading issues.
+        raise
+
+
+def row_parallel_weight_loader(param: torch.Tensor,
+                               loaded_weight: torch.Tensor) -> None:
+    """Load weights that are row-parallelized."""
+    tp_rank = get_tensor_model_parallel_rank()
+    shard_dim = 0 if param.dim() != 1 else None
+
+    if shard_dim is not None:
+        shard_size = param.data.shape[shard_dim]
+        start_idx = tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
+
+    return default_weight_loader(param, loaded_weight)
 
 
 def initialize_dummy_weights(

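The new row_parallel_weight_loader slices dim 0 of the full checkpoint tensor into equal per-rank shards, and falls back to a plain copy for 1-D parameters. A runnable illustration of just that slicing arithmetic, with made-up shapes and no distributed setup:

import torch

def shard_for_rank(loaded_weight: torch.Tensor, param_shape: torch.Size,
                   tp_rank: int) -> torch.Tensor:
    # Mirror of the loader's logic: shard dim 0 unless the param is 1-D.
    if len(param_shape) == 1:
        return loaded_weight
    shard_size = param_shape[0]
    return loaded_weight.narrow(0, tp_rank * shard_size, shard_size)

full = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
shards = [shard_for_rank(full, torch.Size((4, 4)), r) for r in (0, 1)]
assert torch.equal(torch.cat(shards), full)  # the shards tile the checkpoint
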
+ 5 - 2
aphrodite/modeling/models/arctic.py

@@ -431,8 +431,11 @@ class ArcticForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/baichuan.py

@@ -344,8 +344,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/bart.py

@@ -868,8 +868,11 @@ class BartForConditionalGeneration(nn.Module):
         return self.model(input_ids, positions, encoder_input_ids,
                           encoder_positions, kv_caches, attn_metadata)
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/blip2.py

@@ -631,8 +631,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/bloom.py

@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 18 - 5
aphrodite/modeling/models/chameleon.py

@@ -25,8 +25,10 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.image import (cached_get_tokenizer,
                                         repeat_and_pad_image_tokens)
@@ -138,6 +140,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
         super().__init__(hidden_size, *args, **kwargs)
         self.normalized_shape = (hidden_size[-1], )
 
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
+        set_weight_attrs(self.bias,
+                         {"weight_loader": row_parallel_weight_loader})
+
     def forward(self, hidden_states):
         hidden_states = F.layer_norm(hidden_states,
                                      self.normalized_shape,
@@ -694,6 +701,8 @@ class ChameleonVQVAEEncoder(nn.Module):
         )
 
     def forward(self, pixel_values: torch.Tensor):
+        pixel_values = pixel_values.to(self.conv_in.weight.dtype)
+
         # downsampling
         hidden_states = [self.conv_in(pixel_values)]
         for i_level in range(self.num_resolutions):
@@ -956,15 +965,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
 
         # Disallow image tokens which does not include special
         # begin-image and end-image tokens
-        image_tokens = self.model.vocabulary_mapping.image_tokens
-        logits[:, image_tokens] = torch.finfo(logits.dtype).min
+        if logits is not None:
+            image_tokens = self.model.vocabulary_mapping.image_tokens
+            logits[:, image_tokens] = torch.finfo(logits.dtype).min
 
         return logits
 

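Two of the chameleon changes above are defensive: compute_logits only masks image tokens when logits is not None (the non-destination ranks under TP>1), and the VQVAE encoder casts pixel_values to the conv weights' dtype. The cast matters because half-precision checkpoints otherwise reject float32 pixel inputs; a small sketch of the failure it avoids, using an illustrative layer rather than the actual VQVAE:

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(torch.bfloat16)
pixels = torch.rand(1, 3, 32, 32)          # preprocessing emits float32
# conv(pixels) would raise "Input type ... and weight type ... should be
# the same"; casting first, as the diff does, keeps dtypes aligned:
out = conv(pixels.to(conv.weight.dtype))
assert out.dtype == torch.bfloat16
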
+ 5 - 2
aphrodite/modeling/models/chatglm.py

@@ -370,8 +370,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 10 - 19
aphrodite/modeling/models/commandr.py

@@ -25,15 +25,13 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import progress_bar
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -43,7 +41,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
-        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         hidden_states = layer_norm_func(hidden_states, self.weight,
                                         self.variance_epsilon)
         return hidden_states, residuals
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_dim = 0 if param.dim() != 1 else None
-        param_data = param.data
-        if shard_dim is not None:
-            shard_size = param_data.shape[shard_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
             logits = self.logits_processor(self.model.embed_tokens,

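commandr.py is where the loader now living in weight_utils.py came from: its LayerNorm used to carry a private weight_loader with the same narrow-by-rank logic, and it now just attaches the shared row_parallel_weight_loader via set_weight_attrs. A sketch of that attach-a-loader pattern, with a stub standing in for the real set_weight_attrs helper and a placeholder loader:

import torch

def set_weight_attrs(weight: torch.nn.Parameter, attrs: dict) -> None:
    # Stub of the helper: hang extra attributes off the Parameter so the
    # model loader can discover a custom per-weight loading function.
    for key, value in attrs.items():
        setattr(weight, key, value)

def row_parallel_weight_loader(param, loaded_weight):  # placeholder loader
    param.data.copy_(loaded_weight)

w = torch.nn.Parameter(torch.ones(4))
set_weight_attrs(w, {"weight_loader": row_parallel_weight_loader})
# Loading code then resolves roughly:
# getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight)
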
+ 5 - 2
aphrodite/modeling/models/dbrx.py

@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/deepseek.py

@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/deepseek_v2.py

@@ -456,8 +456,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/falcon.py

@@ -395,8 +395,11 @@ class FalconForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/fuyu.py

@@ -285,8 +285,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.language_model.logits_processor(
             self.language_model.lm_head, hidden_states, sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gemma.py

@@ -350,8 +350,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gemma2.py

@@ -342,8 +342,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt2.py

@@ -234,8 +234,11 @@ class GPT2LMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_bigcode.py

@@ -253,8 +253,11 @@ class GPTBigCodeForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_j.py

@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
                                       attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.embed_out, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/internlm2.py

@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.output, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/internvl.py

@@ -461,8 +461,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
                                                   inputs_embeds=inputs_embeds)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/jais.py

@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/jamba.py

@@ -684,8 +684,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         )
         return conv_state_shape, temporal_state_shape
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/llama.py

@@ -425,8 +425,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                   attn_metadata, intermediate_tensors)
         return model_output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/llava.py

@@ -353,8 +353,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/llava_next.py

@@ -583,8 +583,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/mamba.py

@@ -490,8 +490,11 @@ class MambaForCausalLM(nn.Module, HasInnerState):
         )
         return conv_state_shape, temporal_state_shape
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 11 - 5
aphrodite/modeling/models/medusa.py

@@ -66,22 +66,28 @@ class Medusa(nn.Module):
     def compute_logits(
             self, hidden_states: List[torch.Tensor],
             sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
-        logits = []
+        logits_lst: List[torch.Tensor] = []
 
         for hs, lm_head in zip(hidden_states, self.lm_heads):
             _logits = self.logits_processor(lm_head, hs, sampling_metadata)
 
+            if _logits is None:
+                # _logits should only be None on rank > 0, in which case
+                # it should remain true for every lm_head
+                assert len(logits_lst) == 0
+                continue
+
             if self.token_map is None:
-                logits.append(_logits)
+                logits_lst.append(_logits)
             else:
-                logits.append(-torch.inf * torch.ones(
+                logits_lst.append(-torch.inf * torch.ones(
                     size=(*_logits.shape[:-1], self.orig_vocab_size),
                     device=_logits.device,
                     dtype=_logits.dtype))
 
-                logits[-1][..., self.token_map] = _logits
+                logits_lst[-1][..., self.token_map] = _logits
 
-        return logits
+        return logits_lst
 
     def sample(
         self,

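The Medusa loop above encodes the same rank split: on the destination rank every lm_head yields a tensor, while on rank > 0 the logits processor yields None for all heads and the method returns an empty list. A self-contained sketch of that control flow, with fabricated inputs and no distributed setup:

from typing import List, Optional
import torch

def medusa_logits_sketch(
        per_head: List[Optional[torch.Tensor]]) -> List[torch.Tensor]:
    logits_lst: List[torch.Tensor] = []
    for _logits in per_head:
        if _logits is None:
            # None on one head implies None on all heads of this rank.
            assert len(logits_lst) == 0
            continue
        logits_lst.append(_logits)
    return logits_lst

assert medusa_logits_sketch([None, None]) == []                 # rank > 0
assert len(medusa_logits_sketch([torch.zeros(2, 8)] * 2)) == 2  # driver rank
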
+ 5 - 2
aphrodite/modeling/models/minicpm.py

@@ -468,8 +468,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head = self.model.embed_tokens

+ 5 - 2
aphrodite/modeling/models/minicpmv.py

@@ -631,8 +631,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
         )
         return output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mixtral.py

@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mpt.py

@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/nemotron.py

@@ -455,8 +455,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
                                   attn_metadata, intermediate_tensors)
         return model_output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/olmo.py

@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/opt.py

@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/orion.py

@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/paligemma.py

@@ -258,8 +258,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
         return hidden_states
 
     # Copied from vllm/modeling/models/gemma.py
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.language_model.embed_tokens,
                                        hidden_states, sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/persimmon.py

@@ -286,8 +286,11 @@ class PersimmonForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/phi.py

@@ -283,8 +283,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits

+ 5 - 2
aphrodite/modeling/models/phi3_small.py

@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
     def get_decoder(self):
         return self.model
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.dummy_token_indices is not None and logits is not None:

+ 5 - 2
aphrodite/modeling/models/phi3v.py

@@ -580,8 +580,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen.py

@@ -252,8 +252,11 @@ class QWenLMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen2.py

@@ -357,8 +357,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -399,8 +399,11 @@ class Qwen2MoeForCausalLM(nn.Module):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/stablelm.py

@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/starcoder2.py

@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/xverse.py

@@ -325,8 +325,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits