@@ -6,11 +6,11 @@ import torch.nn.functional as F
 from loguru import logger
 from torch.nn.parameter import Parameter
 
-from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   split_tensor_along_last_dim,
-                                   tensor_model_parallel_all_gather,
-                                   tensor_model_parallel_all_reduce)
+from aphrodite.distributed import (
+    divide, get_current_tp_rank_partition_offset,
+    get_current_tp_rank_partition_size, get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size, split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
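The two helpers added to this import are used throughout the diff but defined elsewhere in aphrodite.distributed. For review context, here is a minimal sketch of the semantics the shard math below relies on, assuming the usual balanced-remainder scheme (rank r gets one extra unit when r < total % tp_size, with multiple_of fixing the unit size). The signatures mirror the call sites in this file, but the bodies are an assumption, not the library's actual implementation.

def get_current_tp_rank_partition_size(total_size: int, tp_rank: int,
                                       tp_size: int,
                                       multiple_of: int = 1) -> int:
    # Split `total_size` into tp_size nearly-equal chunks, each a multiple
    # of `multiple_of`; the lowest-numbered ranks absorb the remainder.
    assert total_size % multiple_of == 0
    units = total_size // multiple_of
    return (units // tp_size + int(tp_rank < units % tp_size)) * multiple_of


def get_current_tp_rank_partition_offset(total_size: int, tp_rank: int,
                                         tp_size: int,
                                         multiple_of: int = 1) -> int:
    # Starting index of this rank's chunk: the sum of all earlier chunks.
    assert total_size % multiple_of == 0
    units = total_size // multiple_of
    return ((units // tp_size) * tp_rank +
            min(tp_rank, units % tp_size)) * multiple_of

Under this scheme every rank's partition is within one multiple_of of every other's, and the per-rank sizes always sum to total_size.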
@@ -254,14 +254,17 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output
 
         # Divide the weight matrix along the last dimension.
+        tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         assert self.quant_method is not None
-        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.output_size_per_partition = get_current_tp_rank_partition_size(
+            output_size, tp_rank, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
         # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
-                divide(output_size, tp_size)
+                get_current_tp_rank_partition_size(output_size, tp_rank,
+                                                   tp_size)
                 for output_size in self.output_sizes
             ]
 
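With divide(), output_size had to be divisible by tp_size; the replacement drops that restriction. For example, under the sketched helpers above, output_size=18 with tp_size=4 yields per-rank partitions [5, 5, 4, 4] instead of tripping divide()'s divisibility assert:

# Hypothetical check, reusing the sketched helper from above.
sizes = [get_current_tp_rank_partition_size(18, r, 4) for r in range(4)]
assert sizes == [5, 5, 4, 4] and sum(sizes) == 18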
@@ -349,17 +352,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         quant_config: Quantization configure.
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_sizes: List[int],
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: List[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -417,8 +420,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(
+                get_current_tp_rank_partition_size(output_size, tp_rank,
+                                                   tp_size)
+                for output_size in self.output_sizes[:loaded_shard_id])
+            shard_size = get_current_tp_rank_partition_size(
+                self.output_sizes[loaded_shard_id], tp_rank, tp_size)
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -438,7 +445,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = get_current_tp_rank_partition_offset(
+                loaded_weight.shape[output_dim], tp_rank, tp_size)
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         # Special case for AQLM codebooks.
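Because partitions can now be uneven, the offset inside the fused parameter must be accumulated shard by shard rather than divided, and start_idx must ask where this rank's slice begins in the full (unsharded) checkpoint tensor. A numeric walk-through under the assumed helper semantics:

# Assumed setup: output_sizes = [22, 22] (e.g. gate/up), tp_size = 4,
# tp_rank = 2, loaded_shard_id = 1; each 22-wide shard splits [6, 6, 5, 5].
shard_size = get_current_tp_rank_partition_size(22, 2, 4)    # 5
shard_offset = get_current_tp_rank_partition_size(22, 2, 4)  # 5 (prev shard)
start_idx = get_current_tp_rank_partition_offset(22, 2, 4)   # 6 + 6 = 12
# Rank 2 copies loaded_weight[12:17] into param[5:10] along output_dim.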
@@ -506,14 +514,17 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads = divide(self.total_num_heads, tp_size)
+        tp_rank = get_tensor_model_parallel_rank()
+        self.num_heads_per_kv_head = (self.total_num_heads //
+                                      self.total_num_kv_heads)
+        self.num_kv_heads = get_current_tp_rank_partition_size(
+            self.total_num_kv_heads, tp_rank, tp_size)
+        self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
+        self.num_kv_head_replicas = 1
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
             self.num_kv_head_replicas = divide(tp_size,
                                                self.total_num_kv_heads)
-        else:
-            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-            self.num_kv_head_replicas = 1
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
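KV heads are now partitioned unevenly, and query heads follow their KV group (GQA), so num_heads stays a multiple of num_heads_per_kv_head on every rank. For example, with 32 query heads, 8 KV heads, and tp_size=3, a configuration the old divide() would reject, the sketched helper gives:

# Assumed GQA shapes: 32 query heads, 8 kv heads, group size 4, tp_size = 3.
num_heads_per_kv_head = 32 // 8                                  # 4
kv_per_rank = [get_current_tp_rank_partition_size(8, r, 3)
               for r in range(3)]                                 # [3, 3, 2]
q_per_rank = [kv * num_heads_per_kv_head for kv in kv_per_rank]   # [12, 12, 8]

Since tp_size < total_num_kv_heads here, the replication branch is skipped and each rank keeps whole KV groups.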
@@ -587,13 +598,16 @@ class QKVParallelLinear(ColumnParallelLinear):
             if loaded_shard_id == "q":
                 shard_offset = 0
                 shard_size = self.num_heads * self.head_size
+                multiple_of = self.head_size * self.num_heads_per_kv_head
             elif loaded_shard_id == "k":
                 shard_offset = self.num_heads * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+                multiple_of = self.head_size
             elif loaded_shard_id == "v":
                 shard_offset = (self.num_heads +
                                 self.num_kv_heads) * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+                multiple_of = self.head_size
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -601,6 +615,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+                multiple_of = multiple_of // param.pack_factor
 
                 # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
@@ -624,11 +639,11 @@ class QKVParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            if loaded_shard_id == "q":
-                shard_id = tp_rank
-            else:
-                shard_id = tp_rank // self.num_kv_head_replicas
-            start_idx = shard_id * shard_size
+
+            tp_size = get_tensor_model_parallel_world_size()
+            total_size = loaded_weight.shape[output_dim]
+            start_idx = get_current_tp_rank_partition_offset(
+                total_size, tp_rank, tp_size, multiple_of=multiple_of)
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         # Special case for AQLM codebooks.
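multiple_of tells the offset helper the unit in which the full Q/K/V checkpoint tensor was partitioned: whole KV groups (head_size * num_heads_per_kv_head) for "q", single heads for "k" and "v", so start_idx lands on the same rank boundaries the constructor chose. Continuing the 8-KV-head, tp_size=3 example under the assumed helper semantics, with head_size=128:

# Assumed: head_size = 128, tp_rank = 2 loading the "k" shard (8 heads total).
start_idx_k = get_current_tp_rank_partition_offset(
    8 * 128, 2, 3, multiple_of=128)        # 6 heads * 128 = 768
# For "q" (32 heads, partitioned in groups of 4 heads):
start_idx_q = get_current_tp_rank_partition_offset(
    32 * 128, 2, 3, multiple_of=4 * 128)   # 6 groups * 512 = 3072

Rank 2 thus reads KV heads 6-7 and the matching query heads 24-31, consistent with kv_per_rank = [3, 3, 2] above.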
@@ -678,6 +693,8 @@ class RowParallelLinear(LinearBase):
                        We skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        partition_multiple_of: The input size is partitioned so that each
+            rank's partition size is a multiple of this number.
     """
 
     def __init__(self,
@@ -688,7 +705,8 @@ class RowParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 partition_multiple_of: int = 1):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
@@ -697,7 +715,10 @@ class RowParallelLinear(LinearBase):
 
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.partition_multiple_of = partition_multiple_of
+        self.input_size_per_partition = get_current_tp_rank_partition_size(
+            input_size, self.tp_rank, self.tp_size, partition_multiple_of)
         assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
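RowParallelLinear partitions its input dimension, which for an attention output projection is num_heads * head_size; passing partition_multiple_of keeps this split aligned with the column-parallel Q/K/V split feeding it. A hedged sketch of how a caller might wire this up (the argument names follow this diff, the shapes are assumptions):

# Hypothetical attention o_proj, matching the 32-head, head_size=128 example.
o_proj = RowParallelLinear(input_size=32 * 128,
                           output_size=4096,
                           bias=False,
                           partition_multiple_of=4 * 128)  # one KV group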
@@ -723,12 +744,15 @@ class RowParallelLinear(LinearBase):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
         if input_dim is not None:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = get_current_tp_rank_partition_offset(
+                self.input_size,
+                self.tp_rank,
+                self.tp_size,
+                multiple_of=self.partition_multiple_of)
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
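Under the assumed helper semantics, the loader's start_idx resolves to rank boundaries consistent with the input_size_per_partition chosen in __init__. For the hypothetical o_proj above (input_size=4096, tp_size=3, multiple_of=512, i.e. 8 units):

sizes = [get_current_tp_rank_partition_size(4096, r, 3, 512)
         for r in range(3)]                      # [1536, 1536, 1024]
offsets = [get_current_tp_rank_partition_offset(4096, r, 3, 512)
           for r in range(3)]                    # [0, 1536, 3072]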