12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Adapted from
- # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
- from typing import Sequence
- import torch
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Arguments:
        numerator: value to be divided.
        denominator: value to divide by.
    Raises:
        AssertionError: if ``numerator % denominator`` is not zero.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped when Python runs with ``-O``, which would silently disable this
    # check. AssertionError is kept so callers see the same exception type.
    if numerator % denominator != 0:
        raise AssertionError("{} is not divisible by {}".format(
            numerator, denominator))
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking divisibility.

    Arguments:
        numerator: value to be divided.
        denominator: value to divide by.
    Returns:
        The floor-division quotient of the two arguments.
    """
    # The divisibility check is delegated to ensure_divisibility and runs
    # before the quotient is computed.
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into equally sized chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into; the
            size of the last dimension is passed through ``divide`` together
            with this value to obtain the chunk size.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    Returns:
        A tuple of Tensors.
    """
    # Size of each chunk along the last dimension.
    chunk_width = divide(tensor.shape[-1], num_partitions)
    chunks = torch.split(tensor, chunk_width, dim=-1)
    if not contiguous_split_chunks:
        # NOTE: torch.split does not create contiguous tensors by default,
        # so the chunks are returned exactly as torch.split produced them.
        return chunks
    return tuple(piece.contiguous() for piece in chunks)
|