Pārlūkot izejas kodu

fix: loading chameleon model with TP>1 (#695)

AlpinDale 6 mēneši atpakaļ
vecāks
revīzija
0e558e9b2f
55 mainītis faili ar 313 papildinājumiem un 134 dzēšanām
  1. 5 3
      aphrodite/common/outputs.py
  2. 1 1
      aphrodite/distributed/communication_op.py
  3. 1 1
      aphrodite/distributed/parallel_state.py
  4. 8 4
      aphrodite/modeling/layers/logits_processor.py
  5. 5 2
      aphrodite/modeling/model_loader/neuron.py
  6. 5 2
      aphrodite/modeling/model_loader/openvino.py
  7. 24 2
      aphrodite/modeling/model_loader/weight_utils.py
  8. 5 2
      aphrodite/modeling/models/arctic.py
  9. 5 2
      aphrodite/modeling/models/baichuan.py
  10. 5 2
      aphrodite/modeling/models/bart.py
  11. 5 2
      aphrodite/modeling/models/blip2.py
  12. 5 2
      aphrodite/modeling/models/bloom.py
  13. 18 5
      aphrodite/modeling/models/chameleon.py
  14. 5 2
      aphrodite/modeling/models/chatglm.py
  15. 10 19
      aphrodite/modeling/models/commandr.py
  16. 5 2
      aphrodite/modeling/models/dbrx.py
  17. 5 2
      aphrodite/modeling/models/deepseek.py
  18. 5 2
      aphrodite/modeling/models/deepseek_v2.py
  19. 5 2
      aphrodite/modeling/models/falcon.py
  20. 5 2
      aphrodite/modeling/models/fuyu.py
  21. 5 2
      aphrodite/modeling/models/gemma.py
  22. 5 2
      aphrodite/modeling/models/gemma2.py
  23. 5 2
      aphrodite/modeling/models/gpt2.py
  24. 5 2
      aphrodite/modeling/models/gpt_bigcode.py
  25. 5 2
      aphrodite/modeling/models/gpt_j.py
  26. 5 2
      aphrodite/modeling/models/gpt_neox.py
  27. 5 2
      aphrodite/modeling/models/internlm2.py
  28. 5 2
      aphrodite/modeling/models/internvl.py
  29. 5 2
      aphrodite/modeling/models/jais.py
  30. 5 2
      aphrodite/modeling/models/jamba.py
  31. 5 2
      aphrodite/modeling/models/llama.py
  32. 5 2
      aphrodite/modeling/models/llava.py
  33. 5 2
      aphrodite/modeling/models/llava_next.py
  34. 5 2
      aphrodite/modeling/models/mamba.py
  35. 11 5
      aphrodite/modeling/models/medusa.py
  36. 5 2
      aphrodite/modeling/models/minicpm.py
  37. 5 2
      aphrodite/modeling/models/minicpmv.py
  38. 5 2
      aphrodite/modeling/models/mixtral.py
  39. 5 2
      aphrodite/modeling/models/mixtral_quant.py
  40. 5 2
      aphrodite/modeling/models/mpt.py
  41. 5 2
      aphrodite/modeling/models/nemotron.py
  42. 5 2
      aphrodite/modeling/models/olmo.py
  43. 5 2
      aphrodite/modeling/models/opt.py
  44. 5 2
      aphrodite/modeling/models/orion.py
  45. 5 2
      aphrodite/modeling/models/paligemma.py
  46. 5 2
      aphrodite/modeling/models/persimmon.py
  47. 5 2
      aphrodite/modeling/models/phi.py
  48. 5 2
      aphrodite/modeling/models/phi3_small.py
  49. 5 2
      aphrodite/modeling/models/phi3v.py
  50. 5 2
      aphrodite/modeling/models/qwen.py
  51. 5 2
      aphrodite/modeling/models/qwen2.py
  52. 5 2
      aphrodite/modeling/models/qwen2_moe.py
  53. 5 2
      aphrodite/modeling/models/stablelm.py
  54. 5 2
      aphrodite/modeling/models/starcoder2.py
  55. 5 2
      aphrodite/modeling/models/xverse.py

+ 5 - 3
aphrodite/common/outputs.py

@@ -1,6 +1,8 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
+from typing import Sequence as GenericSequence
+from typing import Union
 
 from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
                                        SampleLogprobs, SequenceGroup,
@@ -29,7 +31,7 @@ class CompletionOutput:
 
     index: int
     text: str
-    token_ids: Tuple[int, ...]
+    token_ids: GenericSequence[int]
     cumulative_logprob: Optional[float]
     logprobs: Optional[SampleLogprobs]
     finish_reason: Optional[str] = None
@@ -139,7 +141,7 @@ class RequestOutput:
             CompletionOutput(
                 seqs.index(seq),
                 seq.get_output_text_to_return(text_buffer_length),
-                seq.data._output_token_ids, # type: ignore
+                seq.data._output_token_ids,
                 seq.get_cumulative_logprob() if include_logprobs else None,
                 seq.output_logprobs if include_logprobs else None,
                 SequenceStatus.get_finished_reason(seq.status),

+ 1 - 1
aphrodite/distributed/communication_op.py

@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
 
 def tensor_model_parallel_gather(input_: torch.Tensor,
                                  dst: int = 0,
-                                 dim: int = -1) -> torch.Tensor:
+                                 dim: int = -1) -> Optional[torch.Tensor]:
     """Gather the input tensor across model parallel group."""
     return get_tp_group().gather(input_, dst, dim)
 

+ 1 - 1
aphrodite/distributed/parallel_state.py

@@ -329,7 +329,7 @@ class GroupCoordinator:
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
-               dim: int = -1) -> torch.Tensor:
+               dim: int = -1) -> Optional[torch.Tensor]:
         """
         NOTE: We assume that the input tensor is on the same device across
         all the ranks.

+ 8 - 4
aphrodite/modeling/layers/logits_processor.py

@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor,
-                    lm_head: VocabParallelEmbedding,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
         if self.use_gather:
+            # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
         else:
             # Gather is not supported for some devices such as TPUs.

+ 5 - 2
aphrodite/modeling/model_loader/neuron.py

@@ -62,8 +62,11 @@ class NeuronCasualLM(nn.Module):
                             start_ids=input_block_ids)
         return logits
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(None, hidden_states, sampling_metadata)
         return logits
 

+ 5 - 2
aphrodite/modeling/model_loader/openvino.py

@@ -179,8 +179,11 @@ class OpenVINOCasualLM(nn.Module):
         # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
         return logits.view(-1, logits.shape[-1])
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
         logits = self.logits_processor(None, hidden_states, sampling_metadata)
         return logits

+ 24 - 2
aphrodite/modeling/model_loader/weight_utils.py

@@ -501,8 +501,30 @@ def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
 
-    assert param.size() == loaded_weight.size()
-    param.data.copy_(loaded_weight)
+    try:
+        assert param.size() == loaded_weight.size(), (
+            f"Attempted to load weight ({loaded_weight.size()}) "
+            f"into parameter ({param.size()})")
+
+        param.data.copy_(loaded_weight)
+    except Exception:
+        # NOTE: This exception is added for the purpose of setting breakpoint to
+        # debug weight loading issues.
+        raise
+
+
+def row_parallel_weight_loader(param: torch.Tensor,
+                               loaded_weight: torch.Tensor) -> None:
+    """Load weights that are row-parallelized."""
+    tp_rank = get_tensor_model_parallel_rank()
+    shard_dim = 0 if param.dim() != 1 else None
+
+    if shard_dim is not None:
+        shard_size = param.data.shape[shard_dim]
+        start_idx = tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
+
+    return default_weight_loader(param, loaded_weight)
 
 
 def initialize_dummy_weights(

+ 5 - 2
aphrodite/modeling/models/arctic.py

@@ -431,8 +431,11 @@ class ArcticForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/baichuan.py

@@ -344,8 +344,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/bart.py

@@ -868,8 +868,11 @@ class BartForConditionalGeneration(nn.Module):
         return self.model(input_ids, positions, encoder_input_ids,
                           encoder_positions, kv_caches, attn_metadata)
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/blip2.py

@@ -631,8 +631,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/bloom.py

@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 18 - 5
aphrodite/modeling/models/chameleon.py

@@ -25,8 +25,10 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.image import (cached_get_tokenizer,
                                         repeat_and_pad_image_tokens)
@@ -138,6 +140,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
         super().__init__(hidden_size, *args, **kwargs)
         self.normalized_shape = (hidden_size[-1], )
 
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
+        set_weight_attrs(self.bias,
+                         {"weight_loader": row_parallel_weight_loader})
+
     def forward(self, hidden_states):
         hidden_states = F.layer_norm(hidden_states,
                                      self.normalized_shape,
@@ -694,6 +701,8 @@ class ChameleonVQVAEEncoder(nn.Module):
         )
 
     def forward(self, pixel_values: torch.Tensor):
+        pixel_values = pixel_values.to(self.conv_in.weight.dtype)
+
         # downsampling
         hidden_states = [self.conv_in(pixel_values)]
         for i_level in range(self.num_resolutions):
@@ -956,15 +965,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
 
         # Disallow image tokens which does not include special
         # begin-image and end-image tokens
-        image_tokens = self.model.vocabulary_mapping.image_tokens
-        logits[:, image_tokens] = torch.finfo(logits.dtype).min
+        if logits is not None:
+            image_tokens = self.model.vocabulary_mapping.image_tokens
+            logits[:, image_tokens] = torch.finfo(logits.dtype).min
 
         return logits
 

+ 5 - 2
aphrodite/modeling/models/chatglm.py

@@ -370,8 +370,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 10 - 19
aphrodite/modeling/models/commandr.py

@@ -25,15 +25,13 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import progress_bar
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -43,7 +41,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
-        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         hidden_states = layer_norm_func(hidden_states, self.weight,
                                         self.variance_epsilon)
         return hidden_states, residuals
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_dim = 0 if param.dim() != 1 else None
-        param_data = param.data
-        if shard_dim is not None:
-            shard_size = param_data.shape[shard_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
             logits = self.logits_processor(self.model.embed_tokens,

+ 5 - 2
aphrodite/modeling/models/dbrx.py

@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/deepseek.py

@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/deepseek_v2.py

@@ -456,8 +456,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/falcon.py

@@ -395,8 +395,11 @@ class FalconForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/fuyu.py

@@ -285,8 +285,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.language_model.logits_processor(
             self.language_model.lm_head, hidden_states, sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gemma.py

@@ -350,8 +350,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gemma2.py

@@ -342,8 +342,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt2.py

@@ -234,8 +234,11 @@ class GPT2LMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_bigcode.py

@@ -253,8 +253,11 @@ class GPTBigCodeForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_j.py

@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits

+ 5 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
                                       attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.embed_out, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/internlm2.py

@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.output, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/internvl.py

@@ -461,8 +461,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
                                                   inputs_embeds=inputs_embeds)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/jais.py

@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/jamba.py

@@ -684,8 +684,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         )
         return conv_state_shape, temporal_state_shape
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/llama.py

@@ -425,8 +425,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                   attn_metadata, intermediate_tensors)
         return model_output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/llava.py

@@ -353,8 +353,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/llava_next.py

@@ -583,8 +583,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 

+ 5 - 2
aphrodite/modeling/models/mamba.py

@@ -490,8 +490,11 @@ class MambaForCausalLM(nn.Module, HasInnerState):
         )
         return conv_state_shape, temporal_state_shape
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 11 - 5
aphrodite/modeling/models/medusa.py

@@ -66,22 +66,28 @@ class Medusa(nn.Module):
     def compute_logits(
             self, hidden_states: List[torch.Tensor],
             sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
-        logits = []
+        logits_lst: List[torch.Tensor] = []
 
         for hs, lm_head in zip(hidden_states, self.lm_heads):
             _logits = self.logits_processor(lm_head, hs, sampling_metadata)
 
+            if _logits is None:
+                # _logits should only be None on rank > 0, in which case
+                # it should remain true for every lm_head
+                assert len(logits_lst) == 0
+                continue
+
             if self.token_map is None:
-                logits.append(_logits)
+                logits_lst.append(_logits)
             else:
-                logits.append(-torch.inf * torch.ones(
+                logits_lst.append(-torch.inf * torch.ones(
                     size=(*_logits.shape[:-1], self.orig_vocab_size),
                     device=_logits.device,
                     dtype=_logits.dtype))
 
-                logits[-1][..., self.token_map] = _logits
+                logits_lst[-1][..., self.token_map] = _logits
 
-        return logits
+        return logits_lst
 
     def sample(
         self,

+ 5 - 2
aphrodite/modeling/models/minicpm.py

@@ -468,8 +468,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head = self.model.embed_tokens

+ 5 - 2
aphrodite/modeling/models/minicpmv.py

@@ -631,8 +631,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
         )
         return output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mixtral.py

@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/mpt.py

@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/nemotron.py

@@ -455,8 +455,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
                                   attn_metadata, intermediate_tensors)
         return model_output
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/olmo.py

@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/opt.py

@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/orion.py

@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/paligemma.py

@@ -258,8 +258,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
         return hidden_states
 
     # Copied from vllm/modeling/models/gemma.py
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.language_model.embed_tokens,
                                        hidden_states, sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/persimmon.py

@@ -286,8 +286,11 @@ class PersimmonForCausalLM(nn.Module):
         )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/phi.py

@@ -283,8 +283,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits

+ 5 - 2
aphrodite/modeling/models/phi3_small.py

@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
     def get_decoder(self):
         return self.model
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.dummy_token_indices is not None and logits is not None:

+ 5 - 2
aphrodite/modeling/models/phi3v.py

@@ -580,8 +580,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
 
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen.py

@@ -252,8 +252,11 @@ class QWenLMHeadModel(nn.Module):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen2.py

@@ -357,8 +357,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -399,8 +399,11 @@ class Qwen2MoeForCausalLM(nn.Module):
                                    attn_metadata, intermediate_tensors)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/stablelm.py

@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/starcoder2.py

@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

+ 5 - 2
aphrodite/modeling/models/xverse.py

@@ -325,8 +325,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits