
feat: Asymmetric Tensor Parallel (#594)

* add utils for computing the partition offset and size for the current TP rank

* disable asymmetric TP for quantization and LoRA, and handle GQA head allocation

* do the actual splitting work in the linear layers

* make the padding size for the vocab/lm_head optional

* update the cache engine and spec decode model runner (kwargs only)

* pass the tp_rank to the model runners

* add llama support
AlpinDale · 7 months ago · commit 5289c14b24
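
The point of asymmetric tensor parallelism is that attention and KV head counts no longer have to divide evenly by tensor_parallel_size; heads are spread as evenly as possible across ranks. A minimal sketch of that split, with hypothetical head counts:

```python
# Even-as-possible split of `total` items across `tp_size` ranks:
# the first (total % tp_size) ranks each take one extra item.
def split_evenly(total: int, tp_size: int) -> list[int]:
    base, rem = divmod(total, tp_size)
    return [base + (rank < rem) for rank in range(tp_size)]

# e.g. 30 attention heads on 4 GPUs -> [8, 8, 7, 7], instead of failing
# the old "must be divisible by tensor parallel size" check.
assert split_evenly(30, 4) == [8, 8, 7, 7]
```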

aphrodite/common/config.py (+30 -9)

@@ -13,6 +13,7 @@ from aphrodite.common.utils import (cuda_device_count_stateless,
                                     get_cpu_memory, is_cpu, is_hip, is_neuron,
                                     is_openvino, is_tpu, is_xpu,
                                     print_warning_once)
+from aphrodite.distributed import get_current_tp_rank_partition_size
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.quantization import QUANTIZATION_METHODS
 from aphrodite.transformers_utils.config import get_config, get_hf_text_config
@@ -350,11 +351,13 @@ class ModelConfig:
         total_num_attention_heads = getattr(self.hf_text_config,
                                             "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
-        if total_num_attention_heads % tensor_parallel_size != 0:
+        if (total_num_attention_heads % tensor_parallel_size != 0
+                and self.quantization is not None):
             raise ValueError(
-                f"Total number of attention heads ({total_num_attention_heads})"
+                f"Total number of attention heads "
+                f"({total_num_attention_heads})"
                 " must be divisible by tensor parallel size "
-                f"({tensor_parallel_size}).")
+                f"({tensor_parallel_size}) when quantization is used.")
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         architectures = getattr(self.hf_config, "architectures", [])
@@ -453,20 +456,32 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self,
+                         parallel_config: "ParallelConfig",
+                         tp_rank: int = 0) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
         # If tensor parallelism is used, we divide the number of KV heads by
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+        result = get_current_tp_rank_partition_size(
+            total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
+        return max(1, result)
 
     def get_num_attention_heads(self,
-                                parallel_config: "ParallelConfig") -> int:
-        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
-        return num_heads // parallel_config.tensor_parallel_size
+                                parallel_config: "ParallelConfig",
+                                tp_rank: int = 0) -> int:
+        if getattr(self.hf_text_config, "num_attention_heads", None) is None:
+            return 0
+
+        num_total_kv_heads = self.get_total_num_kv_heads()
+        num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
+        num_total_attention_heads = self.hf_text_config.num_attention_heads
+        num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
+        # For GQA attention we make sure the whole attention head group is
+        # together on the same GPU.
+        return num_kv_heads * num_heads_per_kv_head
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         from aphrodite.distributed.utils import get_pp_indices
@@ -1366,6 +1381,11 @@ class LoRAConfig:
         if scheduler_config.chunked_prefill_enabled:
             raise ValueError("LoRA is not supported with chunked prefill yet.")
 
+    def verify_with_parallel_config(self, parallel_config: ParallelConfig):
+        if self.lora_vocab_padding_size % parallel_config.world_size != 0:
+            raise ValueError("LoRA vocab padding size must be divisible "
+                             "by world size.")
+
 
 @dataclass
 class PromptAdapterConfig:
@@ -1631,6 +1651,7 @@ class EngineConfig:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
+            self.lora_config.verify_with_parallel_config(self.parallel_config)
         if self.prompt_adapter_config:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
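
To make the GQA handling above concrete, here is a standalone sketch (hypothetical head counts, the partition helper reimplemented inline rather than imported) of what get_num_kv_heads and get_num_attention_heads now return per rank:

```python
def partition_size(total: int, tp_rank: int, tp_size: int) -> int:
    # Mirrors get_current_tp_rank_partition_size with multiple_of=1.
    return total // tp_size + (total % tp_size > tp_rank)

def heads_for_rank(total_q_heads: int, total_kv_heads: int,
                   tp_rank: int, tp_size: int) -> tuple[int, int]:
    kv_heads = max(1, partition_size(total_kv_heads, tp_rank, tp_size))
    q_per_kv = total_q_heads // total_kv_heads
    # Whole GQA groups stay on one GPU: query heads scale with the KV heads.
    return kv_heads * q_per_kv, kv_heads

# Hypothetical model: 32 query heads, 8 KV heads, tensor_parallel_size=3.
print([heads_for_rank(32, 8, r, 3) for r in range(3)])
# -> [(12, 3), (12, 3), (8, 2)]: query heads sum to 32, KV heads to 8.
```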

aphrodite/distributed/parallel_state.py (+32 -0)

@@ -1058,3 +1058,35 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
     return [x == 1 for x in is_in_the_same_node.tolist()]
+
+
+def get_current_tp_rank_partition_offset(total_size: int,
+                                         tp_rank: Optional[int] = None,
+                                         tp_size: Optional[int] = None,
+                                         multiple_of: int = 1) -> int:
+    if tp_rank is None:
+        tp_rank = get_tensor_model_parallel_rank()
+
+    if tp_size is None:
+        tp_size = get_tensor_model_parallel_world_size()
+
+    assert total_size % multiple_of == 0
+    total_size = total_size // multiple_of
+    return ((total_size // tp_size) * tp_rank +
+            min(total_size % tp_size, tp_rank)) * multiple_of
+
+
+def get_current_tp_rank_partition_size(total_size: int,
+                                       tp_rank: Optional[int] = None,
+                                       tp_size: Optional[int] = None,
+                                       multiple_of: int = 1) -> int:
+    if tp_rank is None:
+        tp_rank = get_tensor_model_parallel_rank()
+
+    if tp_size is None:
+        tp_size = get_tensor_model_parallel_world_size()
+
+    assert total_size % multiple_of == 0
+    total_size = total_size // multiple_of
+    return ((total_size // tp_size) +
+            (total_size % tp_size > tp_rank)) * multiple_of
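
A quick self-check of the two helpers above, with the distributed defaults stripped out (ranks passed explicitly, sizes hypothetical): the per-rank slices should tile the whole dimension with no gaps or overlap.

```python
def offset(total: int, rank: int, tp: int, multiple_of: int = 1) -> int:
    assert total % multiple_of == 0
    total //= multiple_of
    return ((total // tp) * rank + min(total % tp, rank)) * multiple_of

def size(total: int, rank: int, tp: int, multiple_of: int = 1) -> int:
    assert total % multiple_of == 0
    total //= multiple_of
    return ((total // tp) + (total % tp > rank)) * multiple_of

# 10 units over 4 ranks -> sizes [3, 3, 2, 2] at offsets [0, 3, 6, 8].
total, tp = 10, 4
assert [size(total, r, tp) for r in range(tp)] == [3, 3, 2, 2]
assert [offset(total, r, tp) for r in range(tp)] == [0, 3, 6, 8]
# Each slice ends exactly where the next one starts.
assert all(offset(total, r, tp) + size(total, r, tp)
           == (offset(total, r + 1, tp) if r + 1 < tp else total)
           for r in range(tp))
```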

aphrodite/modeling/layers/linear.py (+57 -33)

@@ -6,11 +6,11 @@ import torch.nn.functional as F
 from loguru import logger
 from torch.nn.parameter import Parameter
 
-from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   split_tensor_along_last_dim,
-                                   tensor_model_parallel_all_gather,
-                                   tensor_model_parallel_all_reduce)
+from aphrodite.distributed import (
+    divide, get_current_tp_rank_partition_offset,
+    get_current_tp_rank_partition_size, get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size, split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
@@ -254,14 +254,17 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output
 
         # Divide the weight matrix along the last dimension.
+        tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         assert self.quant_method is not None
-        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.output_size_per_partition = get_current_tp_rank_partition_size(
+            output_size, tp_rank, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
         # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
-                divide(output_size, tp_size)
+                get_current_tp_rank_partition_size(output_size, tp_rank,
+                                                   tp_size)
                 for output_size in self.output_sizes
             ]
 
@@ -349,17 +352,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         quant_config: Quantization configure.
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_sizes: List[int],
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: List[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -417,8 +420,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(
+                get_current_tp_rank_partition_size(output_size, tp_rank,
+                                                   tp_size)
+                for output_size in self.output_sizes[:loaded_shard_id])
+            shard_size = get_current_tp_rank_partition_size(
+                self.output_sizes[loaded_shard_id], tp_rank, tp_size)
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -438,7 +445,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = get_current_tp_rank_partition_offset(
+                loaded_weight.shape[output_dim], tp_rank, tp_size)
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         # Special case for AQLM codebooks.
@@ -506,14 +514,17 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads = divide(self.total_num_heads, tp_size)
+        tp_rank = get_tensor_model_parallel_rank()
+        self.num_heads_per_kv_head = (self.total_num_heads //
+                                      self.total_num_kv_heads)
+        self.num_kv_heads = get_current_tp_rank_partition_size(
+            self.total_num_kv_heads, tp_rank, tp_size)
+        self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
+        self.num_kv_head_replicas = 1
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
             self.num_kv_head_replicas = divide(tp_size,
                                                self.total_num_kv_heads)
-        else:
-            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-            self.num_kv_head_replicas = 1
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
@@ -587,13 +598,16 @@ class QKVParallelLinear(ColumnParallelLinear):
             if loaded_shard_id == "q":
                 shard_offset = 0
                 shard_size = self.num_heads * self.head_size
+                multiple_of = self.head_size * self.num_heads_per_kv_head
             elif loaded_shard_id == "k":
                 shard_offset = self.num_heads * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+                multiple_of = self.head_size
             elif loaded_shard_id == "v":
                 shard_offset = (self.num_heads +
                                 self.num_kv_heads) * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+                multiple_of = self.head_size
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -601,6 +615,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+                multiple_of = multiple_of // param.pack_factor
 
                 # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
@@ -624,11 +639,11 @@ class QKVParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            if loaded_shard_id == "q":
-                shard_id = tp_rank
-            else:
-                shard_id = tp_rank // self.num_kv_head_replicas
-            start_idx = shard_id * shard_size
+
+            tp_size = get_tensor_model_parallel_world_size()
+            total_size = loaded_weight.shape[output_dim]
+            start_idx = get_current_tp_rank_partition_offset(
+                total_size, tp_rank, tp_size, multiple_of=multiple_of)
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         # Special case for AQLM codebooks.
@@ -678,6 +693,8 @@ class RowParallelLinear(LinearBase):
                        We skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        partition_multiple_of: The input dimension is split so that each
+                               partition size is a multiple of this number.
     """
 
     def __init__(self,
@@ -688,7 +705,8 @@ class RowParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 partition_multiple_of: int = 1):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
 
@@ -697,7 +715,10 @@ class RowParallelLinear(LinearBase):
 
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.partition_multiple_of = partition_multiple_of
+        self.input_size_per_partition = get_current_tp_rank_partition_size(
+            input_size, self.tp_rank, self.tp_size, partition_multiple_of)
         assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
@@ -723,12 +744,15 @@ class RowParallelLinear(LinearBase):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
-        tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
         if input_dim is not None:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = get_current_tp_rank_partition_offset(
+                self.input_size,
+                self.tp_rank,
+                self.tp_size,
+                multiple_of=self.partition_multiple_of)
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
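
To see how the multiple_of argument keeps the QKV weight loader above aligned to whole heads, here is a rough standalone sketch (hypothetical shapes, helpers reimplemented inline) of the slice each rank takes from a full K projection checkpoint tensor:

```python
def offset(total: int, rank: int, tp: int, multiple_of: int = 1) -> int:
    total //= multiple_of
    return ((total // tp) * rank + min(total % tp, rank)) * multiple_of

def size(total: int, rank: int, tp: int, multiple_of: int = 1) -> int:
    total //= multiple_of
    return ((total // tp) + (total % tp > rank)) * multiple_of

# Hypothetical GQA layer: 8 KV heads, head_size 128, tensor_parallel_size 3.
head_size, total_kv_heads, tp = 128, 8, 3
k_dim = total_kv_heads * head_size  # rows of the full K weight (1024)

slices = [(offset(k_dim, r, tp, multiple_of=head_size),
           size(k_dim, r, tp, multiple_of=head_size)) for r in range(tp)]
print(slices)  # [(0, 384), (384, 384), (768, 256)]
# Every boundary is a multiple of head_size, so no head is split across
# ranks, and the three slices together cover all 1024 rows.
assert sum(s for _, s in slices) == k_dim
```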

aphrodite/modeling/layers/vocab_parallel_embedding.py (+3 -2)

@@ -167,10 +167,11 @@ class VocabParallelEmbedding(torch.nn.Module):
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 padding_size: Optional[int] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
+        padding_size = padding_size or get_tensor_model_parallel_world_size()
         # Keep the input dimensions.
         tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -379,7 +380,7 @@ class ParallelLMHead(VocabParallelEmbedding):
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 padding_size: Optional[int] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config)
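
With padding_size defaulting to the TP world size rather than a fixed constant, the vocabulary only needs to be rounded up to the next multiple of tp_size before it is divided across ranks. A small sketch with hypothetical sizes:

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    # Round n up to the next multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple

# Hypothetical vocab split across 3 GPUs: 32000 is padded to 32001 so the
# (still evenly divided) embedding ends up with 10667 rows per rank.
vocab_size, tp_size = 32000, 3
padded = pad_to_multiple(vocab_size, tp_size)
assert padded == 32001 and padded % tp_size == 0
```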

aphrodite/modeling/models/llama.py (+12 -14)

@@ -31,7 +31,8 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import is_hip
-from aphrodite.distributed import (get_pp_group,
+from aphrodite.distributed import (get_current_tp_rank_partition_size,
+                                   get_pp_group,
                                    get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -43,7 +44,7 @@ from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
 from aphrodite.modeling.models.interfaces import SupportsLoRA
@@ -103,19 +104,15 @@ class LlamaAttention(nn.Module):
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
-            # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
-        else:
-            # Number of KV heads is less than TP size, so we replicate
-            # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.num_kv_heads = max(
+            1,
+            get_current_tp_rank_partition_size(self.total_num_kv_heads,
+                                               tp_rank, tp_size))
+        num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
+        self.num_heads = self.num_kv_heads * num_heads_per_kv_head
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
         self.head_dim = getattr(config, "head_dim",
                                 self.hidden_size // self.total_num_heads)
@@ -138,6 +135,7 @@ class LlamaAttention(nn.Module):
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            partition_multiple_of=num_heads_per_kv_head * self.head_dim,
         )
 
         self.rotary_emb = get_rope(
@@ -368,7 +366,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             self.unpadded_vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            padding_size=None
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
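
Putting the llama changes together for one hypothetical asymmetric setup (32 query heads, 8 KV heads, head_dim 128, tensor_parallel_size 3): every rank keeps whole GQA groups, and each rank's o_proj input partition is a multiple of partition_multiple_of = num_heads_per_kv_head * head_dim. A rough sketch:

```python
def partition_size(total: int, rank: int, tp: int, multiple_of: int = 1) -> int:
    total //= multiple_of
    return ((total // tp) + (total % tp > rank)) * multiple_of

total_heads, total_kv_heads, head_dim, tp = 32, 8, 128, 3
q_per_kv = total_heads // total_kv_heads   # 4 query heads per KV head
multiple_of = q_per_kv * head_dim          # 512: o_proj alignment unit

for rank in range(tp):
    kv_heads = max(1, partition_size(total_kv_heads, rank, tp))
    q_heads = kv_heads * q_per_kv
    o_proj_in = q_heads * head_dim         # this rank's o_proj input size
    assert o_proj_in % multiple_of == 0
    print(rank, q_heads, kv_heads, o_proj_in)
# rank 0: 12 query heads, 3 KV heads, o_proj input 1536
# rank 1: 12 query heads, 3 KV heads, o_proj input 1536
# rank 2:  8 query heads, 2 KV heads, o_proj input 1024
# The three partitions sum to 4096 = total_heads * head_dim.
```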

aphrodite/spec_decode/draft_model_runner.py (+2 -0)

@@ -47,6 +47,7 @@ class TP1DraftModelRunner(ModelRunner):
         multimodal_config: Optional[MultiModalConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         return_hidden_states: bool = False,
+        **kwargs,
     ):
         if return_hidden_states:
             raise ValueError(
@@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner):
             multimodal_config=multimodal_config,
             prompt_adapter_config=prompt_adapter_config,
             return_hidden_states=return_hidden_states,
+            **kwargs,
         )
 
         # TODO: Remove this cache when we are able to update model_input

aphrodite/task_handler/cache_engine.py (+6 -3)

@@ -24,6 +24,7 @@ class CacheEngine:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         device_config: DeviceConfig,
+        tp_rank: int = 0,
     ) -> None:
         self.cache_config = cache_config
         self.model_config = model_config
@@ -34,7 +35,8 @@ class CacheEngine:
         # Models like Jamba have mixed-type layers, e.g. Mamba
         self.num_attention_layers = model_config.get_num_attention_layers(
             parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        self.num_kv_heads = model_config.get_num_kv_heads(
+            parallel_config, tp_rank)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -51,7 +53,7 @@ class CacheEngine:
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
-            model_config.get_num_attention_heads(parallel_config),
+            model_config.get_num_attention_heads(parallel_config, tp_rank),
             self.head_size,
             self.num_kv_heads,
             model_config.get_sliding_window(),
@@ -104,9 +106,10 @@ class CacheEngine:
         cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
+        tp_rank: int = 0,
     ) -> int:
         head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_kv_heads(parallel_config)
+        num_heads = model_config.get_num_kv_heads(parallel_config, tp_rank)
         num_attention_layers = model_config.get_num_attention_layers(
             parallel_config)
 
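
Because get_num_kv_heads is now rank-dependent, KV cache block sizes can differ across ranks. A back-of-the-envelope sketch with hypothetical numbers (32 layers, 8 KV heads, head_size 128, block_size 16, 2-byte cache dtype); it follows the general shape of get_cache_block_size rather than reproducing it line for line:

```python
def kv_heads_for_rank(total_kv_heads: int, rank: int, tp: int) -> int:
    return max(1, total_kv_heads // tp + (total_kv_heads % tp > rank))

def cache_block_bytes(block_size: int, num_kv_heads: int, head_size: int,
                      num_layers: int, dtype_bytes: int = 2) -> int:
    # One block stores `block_size` tokens of both K and V for every layer.
    per_layer = 2 * block_size * num_kv_heads * head_size * dtype_bytes
    return num_layers * per_layer

tp, total_kv_heads = 3, 8
for rank in range(tp):
    heads = kv_heads_for_rank(total_kv_heads, rank, tp)
    print(rank, heads, cache_block_bytes(16, heads, 128, 32))
# rank 0: 3 KV heads -> 786432 bytes per block
# rank 1: 3 KV heads -> 786432 bytes per block
# rank 2: 2 KV heads -> 524288 bytes per block
```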

aphrodite/task_handler/embedding_model_runner.py (+3 -1)

@@ -41,6 +41,7 @@ class EmbeddingModelRunner(
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         multimodal_config: Optional[MultiModalConfig] = None,
+        tp_rank: int = 0,
     ):
         super().__init__(model_config,
                          parallel_config,
@@ -52,7 +53,8 @@ class EmbeddingModelRunner(
                          kv_cache_dtype=kv_cache_dtype,
                          is_driver_worker=is_driver_worker,
                          prompt_adapter_config=prompt_adapter_config,
-                         multimodal_config=multimodal_config)
+                         multimodal_config=multimodal_config,
+                         tp_rank=tp_rank)
 
     @torch.inference_mode()
     def execute_model(

aphrodite/task_handler/model_runner.py (+7 -4)

@@ -187,6 +187,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         multimodal_config: Optional[MultiModalConfig] = None,
         return_hidden_states: bool = False,
+        tp_rank: int = 0,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -203,6 +204,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
 
+        self.tp_rank = tp_rank
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
@@ -226,11 +228,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
             dtype=np.int32)
         num_attn_heads = self.model_config.get_num_attention_heads(
-            self.parallel_config)
+            self.parallel_config, self.tp_rank)
         self.attn_backend = get_attn_backend(
             num_attn_heads,
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_num_kv_heads(self.parallel_config,
+                                               self.tp_rank),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
@@ -780,9 +783,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 paged_kv_indices=paged_kv_indices_tensor,
                 paged_kv_last_page_len=paged_kv_last_page_len_tensor,
                 num_qo_heads=self.model_config.get_num_attention_heads(
-                    self.parallel_config),
+                    self.parallel_config, self.tp_rank),
                 num_kv_heads=self.model_config.get_num_kv_heads(
-                    self.parallel_config),
+                    self.parallel_config, self.tp_rank),
                 head_dim=self.model_config.get_head_size(),
                 page_size=self.block_size,
                 seq_start_loc=seq_start_loc,

aphrodite/task_handler/worker.py (+2 -1)

@@ -102,6 +102,7 @@ class Worker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             prompt_adapter_config=prompt_adapter_config,
             multimodal_config=multimodal_config,
+            tp_rank=self.rank,
             **speculative_args,
         )
         # Uninitialized cache engine. Will be initialized by
@@ -226,7 +227,7 @@ class Worker(LocalOrDistributedWorkerBase):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = [
             CacheEngine(self.cache_config, self.model_config,
-                        self.parallel_config, self.device_config)
+                        self.parallel_config, self.device_config, self.rank)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.gpu_cache = [