# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence, Tuple

import torch

import aphrodite.common.envs as envs

APHRODITE_PP_LAYER_PARTITION = envs.APHRODITE_PP_LAYER_PARTITION


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor.
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Try to evenly distribute layers across partitions.
    If the number of layers is not divisible by the number of partitions,
    the last partition will have the remaining layers.
    """
    partition_list_str = APHRODITE_PP_LAYER_PARTITION
    if partition_list_str is not None:
        try:
            partitions = [
                int(layer) for layer in partition_list_str.split(",")
            ]
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
        start_layer = sum(partitions[:pp_rank])
        end_layer = start_layer + partitions[pp_rank]
    else:
        layers_per_partition = num_hidden_layers // pp_size
        start_layer = pp_rank * layers_per_partition
        end_layer = start_layer + layers_per_partition

        if pp_rank == pp_size - 1:
            end_layer = num_hidden_layers

    return (start_layer, end_layer)
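

# A minimal usage sketch, not part of the original module: it illustrates how
# the helpers above behave for an assumed 2-stage pipeline over 8 hidden
# layers (with APHRODITE_PP_LAYER_PARTITION unset) and a 2-way split of a
# small tensor along its last dimension. The concrete sizes here are
# illustrative assumptions, not library defaults.
if __name__ == "__main__":
    # Each pipeline rank receives a contiguous [start, end) slice of layers.
    # With 8 layers over 2 ranks this yields (0, 4) and (4, 8); any remainder
    # from an uneven division is absorbed by the last rank.
    for rank in range(2):
        print(rank, get_pp_indices(num_hidden_layers=8, pp_rank=rank,
                                   pp_size=2))

    # Splitting a (2, 8) tensor into 2 partitions along the last dimension
    # yields two (2, 4) chunks; contiguous_split_chunks=True copies each chunk
    # so downstream consumers see contiguous memory.
    x = torch.arange(16, dtype=torch.float32).reshape(2, 8)
    chunks = split_tensor_along_last_dim(x, 2, contiguous_split_chunks=True)
    print([tuple(chunk.shape) for chunk in chunks])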