# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence

import torch
  8. def ensure_divisibility(numerator, denominator):
  9. """Ensure that numerator is divisible by the denominator."""
  10. assert numerator % denominator == 0, "{} is not divisible by {}".format(
  11. numerator, denominator)
  12. def divide(numerator, denominator):
  13. """Ensure that numerator is divisible by the denominator and return
  14. the division value."""
  15. ensure_divisibility(numerator, denominator)
  16. return numerator // denominator
  17. def split_tensor_along_last_dim(
  18. tensor: torch.Tensor,
  19. num_partitions: int,
  20. contiguous_split_chunks: bool = False,
  21. ) -> Sequence[torch.Tensor]:
  22. """ Split a tensor along its last dimension.
  23. Arguments:
  24. tensor: input tensor.
  25. num_partitions: number of partitions to split the tensor
  26. contiguous_split_chunks: If True, make each chunk contiguous
  27. in memory.
  28. Returns:
  29. A list of Tensors
  30. """
  31. # Get the size and dimension.
  32. last_dim = tensor.dim() - 1
  33. last_dim_size = divide(tensor.size()[last_dim], num_partitions)
  34. # Split.
  35. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  36. # NOTE: torch.split does not create contiguous tensors by default.
  37. if contiguous_split_chunks:
  38. return tuple(chunk.contiguous() for chunk in tensor_list)
  39. return tensor_list