quantized lm_head (#582)

* wip

* wip
AlpinDale, 6 months ago
commit 0f4a9ee77b
43 changed files with 213 additions and 111 deletions
  1. aphrodite/lora/layers.py (+2 −2)
  2. aphrodite/modeling/layers/logits_processor.py (+9 −6)
  3. aphrodite/modeling/layers/vocab_parallel_embedding.py (+52 −17)
  4. aphrodite/modeling/models/arctic.py (+2 −1)
  5. aphrodite/modeling/models/baichuan.py (+4 −2)
  6. aphrodite/modeling/models/bloom.py (+2 −2)
  7. aphrodite/modeling/models/chatglm.py (+4 −3)
  8. aphrodite/modeling/models/commandr.py (+4 −4)
  9. aphrodite/modeling/models/dbrx.py (+2 −1)
  10. aphrodite/modeling/models/deepseek.py (+4 −2)
  11. aphrodite/modeling/models/deepseek_v2.py (+4 −2)
  12. aphrodite/modeling/models/falcon.py (+2 −2)
  13. aphrodite/modeling/models/gemma.py (+2 −2)
  14. aphrodite/modeling/models/gpt2.py (+2 −2)
  15. aphrodite/modeling/models/gpt_bigcode.py (+2 −2)
  16. aphrodite/modeling/models/gpt_j.py (+2 −1)
  17. aphrodite/modeling/models/gpt_neox.py (+2 −1)
  18. aphrodite/modeling/models/internlm2.py (+4 −2)
  19. aphrodite/modeling/models/jais.py (+2 −2)
  20. aphrodite/modeling/models/llama.py (+2 −1)
  21. aphrodite/modeling/models/llava.py (+3 −2)
  22. aphrodite/modeling/models/llava_next.py (+3 −2)
  23. aphrodite/modeling/models/minicpm.py (+4 −3)
  24. aphrodite/modeling/models/mixtral.py (+2 −1)
  25. aphrodite/modeling/models/mixtral_quant.py (+4 −2)
  26. aphrodite/modeling/models/mlp_speculator.py (+5 −5)
  27. aphrodite/modeling/models/mpt.py (+2 −2)
  28. aphrodite/modeling/models/olmo.py (+3 −3)
  29. aphrodite/modeling/models/opt.py (+2 −2)
  30. aphrodite/modeling/models/orion.py (+4 −2)
  31. aphrodite/modeling/models/phi.py (+3 −2)
  32. aphrodite/modeling/models/phi3_small.py (+2 −1)
  33. aphrodite/modeling/models/phi3v.py (+4 −2)
  34. aphrodite/modeling/models/qwen.py (+4 −2)
  35. aphrodite/modeling/models/qwen2.py (+4 −4)
  36. aphrodite/modeling/models/qwen2_moe.py (+4 −2)
  37. aphrodite/modeling/models/stablelm.py (+4 −2)
  38. aphrodite/modeling/models/starcoder2.py (+3 −3)
  39. aphrodite/modeling/models/xverse.py (+4 −2)
  40. aphrodite/quantization/base_config.py (+9 −0)
  41. aphrodite/quantization/gptq.py (+10 −3)
  42. aphrodite/quantization/gptq_marlin.py (+11 −4)
  43. aphrodite/quantization/marlin.py (+10 −3)

+ 2 - 2
aphrodite/lora/layers.py

@@ -1167,11 +1167,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
+        logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
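
The switch from a raw matmul to lm_head.linear_method.apply(...) is the core of this change: the logits projection now goes through the same quantize-method dispatch as any linear layer, so a quantized lm_head can use GPTQ/Marlin kernels. A minimal sketch (hypothetical stand-in classes, not the Aphrodite code) of why the unquantized path is numerically unchanged:

import torch
import torch.nn.functional as F

class UnquantizedMethodSketch:
    """Stand-in for UnquantizedLinearMethod.apply (assumption: the real
    method is a plain F.linear over layer.weight)."""
    def apply(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)

class HeadSketch(torch.nn.Module):
    """Stand-in for ParallelLMHead carrying a linear_method."""
    def __init__(self, vocab, hidden):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab, hidden))
        self.linear_method = UnquantizedMethodSketch()

head = HeadSketch(vocab=100, hidden=16)
x = torch.randn(2, 16)
old = torch.matmul(x, head.weight.t())   # previous code path
new = head.linear_method.apply(head, x)  # new dispatch
assert torch.allclose(old, new)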

+ 9 - 6
aphrodite/modeling/layers/logits_processor.py

@@ -6,6 +6,8 @@ import torch
 import torch.nn as nn
 
 from aphrodite.distributed import tensor_model_parallel_gather
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 
@@ -37,7 +39,7 @@ class LogitsProcessor(nn.Module):
 
     def forward(
         self,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
@@ -49,7 +51,7 @@ class LogitsProcessor(nn.Module):
                                                  sampling_metadata)
 
             # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
 
         if logits is not None:
             if self.scale != 1.0:
@@ -60,12 +62,13 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+    def _get_logits(self, hidden_states: torch.Tensor,
+                    lm_head: VocabParallelEmbedding,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
+        logits = lm_head.linear_method.apply(lm_head,
+                                             hidden_states,
+                                             bias=embedding_bias)
         logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:

+ 52 - 17
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -8,7 +8,10 @@ from torch.nn.parameter import Parameter
 from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -156,6 +159,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         params_dtype: type of the parameters.
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
+        quant_config: quant config for the layer.
     """  # noqa: E501
 
     def __init__(self,
@@ -163,7 +167,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
         # Keep the input dimensions.
@@ -186,6 +191,14 @@ class VocabParallelEmbedding(torch.nn.Module):
                                                self.org_vocab_size, tp_rank,
                                                self.tp_size)
         self.embedding_dim = embedding_dim
+
+        linear_method = None
+        if quant_config is not None:
+            linear_method = quant_config.get_quant_method(self)
+        if linear_method is None:
+            linear_method = UnquantizedLinearMethod()
+        self.linear_method: QuantizeMethodBase = linear_method
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocabulary dimension.
@@ -200,14 +213,13 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_added_embeddings_per_partition = (
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)
-        self.weight = Parameter(
-            torch.empty(self.num_embeddings_per_partition,
-                        self.embedding_dim,
-                        dtype=params_dtype))
-        set_weight_attrs(self.weight, {
-            "parallel_dim": 0,
-            "weight_loader": self.weight_loader
-        })
+        self.linear_method.create_weights(self,
+                                          self.embedding_dim,
+                                          [self.num_embeddings_per_partition],
+                                          self.embedding_dim,
+                                          self.num_embeddings_padded,
+                                          params_dtype=params_dtype,
+                                          weight_loader=self.weight_loader)
 
     @classmethod
     def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -287,10 +299,32 @@ class VocabParallelEmbedding(torch.nn.Module):
         return ret
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
-        loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
-                                      self.shard_indices.org_vocab_end_index]
+        output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
+
+        # If parameter does not have output dim, then it should
+        # be copied onto all gpus (e.g. g_idx for act_order gptq).
+        if output_dim is None:
+            assert param.data.shape == loaded_weight.shape
+            param.data.copy_(loaded_weight)
+            return
+
+        # Shard indexes for loading the weight
+        start_idx = self.shard_indices.org_vocab_start_index
+        shard_size = self.shard_indices.org_vocab_end_index - start_idx
+
+        # If param packed on the same dim we are sharding on, then
+        # need to adjust offsets of loaded weight by pack_factor.
+        if packed_dim is not None and packed_dim == output_dim:
+            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
+                                                       param.pack_factor)
+            start_idx = start_idx // param.pack_factor
+            shard_size = shard_size // param.pack_factor
+        else:
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+
+        # Copy the data.
+        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
 
@@ -345,16 +379,17 @@ class ParallelLMHead(VocabParallelEmbedding):
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size)
+                         org_num_embeddings, padding_size, quant_config)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
-                "parallel_dim": 0,
-                "weight_loader": self.weight_loader
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
         else:
             self.register_parameter("bias", None)
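
The rewritten weight_loader above handles GPTQ-style packed checkpoints: when several quantized values are packed per int32 along the vocab dimension, the shard offsets must be rescaled by pack_factor before narrowing. A small worked example of that index math, with hypothetical sizes:

import torch

org_vocab_size, hidden, pack_factor = 32000, 4096, 8  # 4-bit GPTQ: 8 values per int32
packed = torch.zeros(org_vocab_size // pack_factor, hidden, dtype=torch.int32)

# This rank owns vocab rows [16000, 32000); rescale into packed rows.
start_idx, shard_size = 16000, 16000
start_idx //= pack_factor   # 2000 packed rows
shard_size //= pack_factor  # 2000 packed rows
shard = packed.narrow(0, start_idx, shard_size)
assert shard.shape == (2000, hidden)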

+ 2 - 1
aphrodite/modeling/models/arctic.py

@@ -409,6 +409,7 @@ class ArcticForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(
             self.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.num_experts = config.num_local_experts
         self.num_experts_per_tok = config.num_experts_per_tok
@@ -431,7 +432,7 @@ class ArcticForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/baichuan.py

@@ -325,7 +325,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, cache_config,
                                    quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -343,7 +345,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/bloom.py

@@ -275,7 +275,7 @@ class BloomForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = BloomModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head = self.transformer.word_embeddings
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -293,7 +293,7 @@ class BloomForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 3
aphrodite/modeling/models/chatglm.py

@@ -301,7 +301,8 @@ class ChatGLMModel(nn.Module):
         self.encoder = GLMTransformer(config, cache_config, quant_config)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
-                                           config.hidden_size)
+                                           config.hidden_size,
+                                           quant_config=quant_config)
 
     def forward(
         self,
@@ -352,7 +353,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.output_layer.weight
+        self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()
 
@@ -370,7 +371,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 4
aphrodite/modeling/models/commandr.py

@@ -362,12 +362,12 @@ class CohereForCausalLM(nn.Module):
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
-            embedding_weights = self.model.embed_tokens.weight
+            logits = self.logits_processor(self.model.embed_tokens,
+                                           hidden_states, sampling_metadata)
         else:
-            embedding_weights = self.model.embed_tokens.base_layer.weight
+            logits = self.logits_processor(self.model.embed_tokens.base_layer,
+                                           hidden_states, sampling_metadata)
 
-        logits = self.logits_processor(embedding_weights, hidden_states,
-                                       sampling_metadata)
         return logits
 
     def sample(

+ 2 - 1
aphrodite/modeling/models/dbrx.py

@@ -369,6 +369,7 @@ class DbrxForCausalLM(nn.Module):
             config.d_model,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -388,7 +389,7 @@ class DbrxForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/deepseek.py

@@ -376,7 +376,9 @@ class DeepseekForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -394,7 +396,7 @@ class DeepseekForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/deepseek_v2.py

@@ -465,7 +465,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -483,7 +485,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/falcon.py

@@ -374,7 +374,7 @@ class FalconForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = FalconModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head = self.transformer.word_embeddings
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -396,7 +396,7 @@ class FalconForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/gemma.py

@@ -340,8 +340,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(

+ 2 - 2
aphrodite/modeling/models/gpt2.py

@@ -217,7 +217,7 @@ class GPT2LMHeadModel(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = GPT2Model(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -235,7 +235,7 @@ class GPT2LMHeadModel(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/gpt_bigcode.py

@@ -236,7 +236,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -254,7 +254,7 @@ class GPTBigCodeForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 1
aphrodite/modeling/models/gpt_j.py

@@ -228,6 +228,7 @@ class GPTJForCausalLM(nn.Module):
             config.vocab_size,
             config.n_embd,
             bias=True,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -246,7 +247,7 @@ class GPTJForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 

+ 2 - 1
aphrodite/modeling/models/gpt_neox.py

@@ -240,6 +240,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self.embed_out = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -258,7 +259,7 @@ class GPTNeoXForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.embed_out.weight, hidden_states,
+        logits = self.logits_processor(self.embed_out, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/internlm2.py

@@ -252,7 +252,9 @@ class InternLM2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = InternLM2Model(config, cache_config, quant_config)
-        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.output = ParallelLMHead(config.vocab_size,
+                                     config.hidden_size,
+                                     quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -270,7 +272,7 @@ class InternLM2ForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.output.weight, hidden_states,
+        logits = self.logits_processor(self.output, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/jais.py

@@ -272,7 +272,7 @@ class JAISLMHeadModel(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = JAISModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         if hasattr(config, "width_scale"):
             self.output_logits_scale = config.width_scale
         else:
@@ -296,7 +296,7 @@ class JAISLMHeadModel(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 1
aphrodite/modeling/models/llama.py

@@ -379,6 +379,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
         )
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -402,7 +403,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 3 - 2
aphrodite/modeling/models/llava.py

@@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size)
+            org_num_embeddings=self.language_model.org_vocab_size,
+            quant_config=quant_config)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
@@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 3 - 2
aphrodite/modeling/models/llava_next.py

@@ -183,7 +183,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size)
+            org_num_embeddings=self.language_model.org_vocab_size,
+            quant_config=quant_config)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
@@ -435,7 +436,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 3
aphrodite/modeling/models/minicpm.py

@@ -445,6 +445,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                 # We need bigger padding if using lora for kernel
                 # compatibility
                 if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
             )
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
@@ -468,10 +469,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        logits = self.logits_processor(lm_head_weight, hidden_states,
+            lm_head = self.lm_head
+        logits = self.logits_processor(lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 1
aphrodite/modeling/models/mixtral.py

@@ -330,6 +330,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -349,7 +350,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -343,7 +343,9 @@ class MixtralForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -361,7 +363,7 @@ class MixtralForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 5 - 5
aphrodite/modeling/models/mlp_speculator.py

@@ -8,8 +8,8 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
 
@@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
             self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                       (self.max_speculative_tokens - 1))
 
-            head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
             self.head = nn.ModuleList([head] * self.max_speculative_tokens)
 
             ln = MLPSpeculatorLayerNorm(self.inner_dim,
@@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
             # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
 
-            logits = self.logits_processor(self.head[head_index].weight,
-                                           states, sampling_metadata)
+            logits = self.logits_processor(self.head[head_index], states,
+                                           sampling_metadata)
 
             output = self.sampler(logits.flatten(0, 1), sampling_metadata)
             last_tokens = output.sampled_token_ids
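
mlp_speculator previously built its heads from plain nn.Linear, which carries no quantize-method hook; switching to ParallelLMHead gives the speculator heads the interface the LogitsProcessor now expects. A toy check of that contract (illustrative only):

import torch.nn as nn

head = nn.Linear(64, 100, bias=False)
# A plain nn.Linear has no quantize-method hook:
assert not hasattr(head, "linear_method")
# ParallelLMHead (per the vocab_parallel_embedding.py diff) always sets
# self.linear_method, falling back to UnquantizedLinearMethod.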

+ 2 - 2
aphrodite/modeling/models/mpt.py

@@ -262,7 +262,7 @@ class MPTForCausalLM(nn.Module):
         self.quant_config = quant_config
 
         self.transformer = MPTModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -280,7 +280,7 @@ class MPTForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 3 - 3
aphrodite/modeling/models/olmo.py

@@ -282,15 +282,15 @@ class OlmoForCausalLM(nn.Module):
         self.config = config
         self.model = OlmoModel(config, cache_config, quant_config)
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.unpadded_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -312,7 +312,7 @@ class OlmoForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 2 - 2
aphrodite/modeling/models/opt.py

@@ -293,7 +293,7 @@ class OPTForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = OPTModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.model.decoder.embed_tokens.weight
+        self.lm_head = self.model.decoder.embed_tokens
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -311,7 +311,7 @@ class OPTForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/orion.py

@@ -258,7 +258,9 @@ class OrionForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = OrionModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -276,7 +278,7 @@ class OrionForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 3 - 2
aphrodite/modeling/models/phi.py

@@ -264,7 +264,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
 
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      bias=True)
+                                      bias=True,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -283,7 +284,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 

+ 2 - 1
aphrodite/modeling/models/phi3_small.py

@@ -365,6 +365,7 @@ class Phi3SmallForCausalLM(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -399,7 +400,7 @@ class Phi3SmallForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)

+ 4 - 2
aphrodite/modeling/models/phi3v.py

@@ -362,7 +362,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self.model = LlamaModel(config, cache_config, quant_config)
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             vlm_config, config, self.model.embed_tokens)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -406,7 +408,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/qwen.py

@@ -233,7 +233,9 @@ class QWenLMHeadModel(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = QWenModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -251,7 +253,7 @@ class QWenLMHeadModel(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 4
aphrodite/modeling/models/qwen2.py

@@ -311,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
         self.model = Qwen2Model(config, cache_config, quant_config)
 
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size)
-            self.lm_head_weight = self.lm_head.weight
+                                          config.hidden_size,
+                                          quant_config=quant_config)
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -334,7 +334,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -361,7 +361,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -379,7 +381,7 @@ class Qwen2MoeForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/stablelm.py

@@ -239,7 +239,9 @@ class StablelmForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = StableLMEpochModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -257,7 +259,7 @@ class StablelmForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 3 - 3
aphrodite/modeling/models/starcoder2.py

@@ -241,7 +241,7 @@ class Starcoder2ForCausalLM(nn.Module):
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.unpadded_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
@@ -249,8 +249,8 @@ class Starcoder2ForCausalLM(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -269,7 +269,7 @@ class Starcoder2ForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 4 - 2
aphrodite/modeling/models/xverse.py

@@ -306,7 +306,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
         self.lora_config = lora_config
         self.quant_config = quant_config
         self.model = XverseModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -324,7 +326,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 

+ 9 - 0
aphrodite/quantization/base_config.py

@@ -84,6 +84,15 @@ class QuantizationConfig(ABC):
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")
 
+    @staticmethod
+    def get_from_keys_or(config: Dict[str, Any], keys: List[str],
+                         default: Any) -> Any:
+        """Get a optional value from the model's quantization config."""
+        try:
+            return QuantizationConfig.get_from_keys(config, keys)
+        except ValueError:
+            return default
+
     @abstractmethod
     def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
         """Get the quantize method to use for the quantized layer."""

+ 10 - 3
aphrodite/quantization/gptq.py

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
@@ -23,10 +24,12 @@ class GPTQConfig(QuantizationConfig):
         weight_bits: int,
         group_size: int,
         desc_act: bool,
+        lm_head_quantized: bool,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
         self.pack_factor = Fraction(32, self.weight_bits)
         if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
@@ -36,7 +39,8 @@ class GPTQConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}),"
+                f"lm_head_quantized={self.lm_head_quantized}")
 
     @classmethod
     def get_name(cls) -> str:
@@ -60,11 +64,14 @@ class GPTQConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
-        return cls(weight_bits, group_size, desc_act)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQLinearMethod(self)
         return None
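
With the flag wired through from_config, a checkpoint opts its lm_head into GPTQ kernels via "lm_head": true in its quantization config; absent the key, the head stays unquantized. A sketch of the resulting behavior, using the classes from the diff above:

cfg = GPTQConfig.from_config({"bits": 4, "group_size": 128,
                              "desc_act": False, "lm_head": True})
assert cfg.lm_head_quantized is True
# get_quant_method(layer) now returns GPTQLinearMethod for a
# ParallelLMHead as well, not just for LinearBase layers; the same
# pattern is applied to gptq_marlin.py and marlin.py below.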
 

+ 11 - 4
aphrodite/quantization/gptq_marlin.py

@@ -10,6 +10,7 @@ from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import get_device_capability_stateless
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
+from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.quantization.base_config import QuantizationConfig
 
 GPTQ_MARLIN_TILE = 16
@@ -55,7 +56,7 @@ class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
     def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool) -> None:
+                 is_sym: bool, lm_head_quantized: bool) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -65,6 +66,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         self.group_size = group_size
         self.desc_act = desc_act
         self.is_sym = is_sym
+        self.lm_head_quantized = lm_head_quantized
 
         # Verify
         if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@@ -92,7 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -116,7 +119,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         is_sym = cls.get_from_keys(config, ["sym"])
-        return cls(weight_bits, group_size, desc_act, is_sym)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, is_sym,
+                   lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -141,7 +147,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
             self,
             layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQMarlinLinearMethod(self)
         return None
 

+ 10 - 3
aphrodite/quantization/marlin.py

@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
 
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
@@ -19,9 +20,11 @@ class MarlinConfig(QuantizationConfig):
     def __init__(
         self,
         group_size: int,
+        lm_head_quantized: bool,
     ) -> None:
         # Group size for the quantization.
         self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
         if self.group_size != 128 and self.group_size != -1:
             raise ValueError(
                 "Currently, only group size 128 and -1 (channelwise) "
@@ -48,7 +51,8 @@ class MarlinConfig(QuantizationConfig):
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return f"MarlinConfig(group_size={self.group_size})"
+        return (f"MarlinConfig(group_size={self.group_size}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -70,7 +74,9 @@ class MarlinConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
-        return cls(group_size)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(group_size, lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -93,7 +99,8 @@ class MarlinConfig(QuantizationConfig):
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return MarlinLinearMethod(self)
         return None