1
0

utils.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright 2023 The PygmalionAI team.
  2. # Copyright 2023 The vLLM team.
  3. # Adapted from
  4. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
  5. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  6. from typing import Sequence, Tuple
  7. import torch
  8. def ensure_divisibility(numerator, denominator):
  9. """Ensure that numerator is divisible by the denominator."""
  10. assert numerator % denominator == 0, "{} is not divisible by {}".format(
  11. numerator, denominator)
  12. def divide(numerator, denominator):
  13. """Ensure that numerator is divisible by the denominator and return
  14. the division value."""
  15. ensure_divisibility(numerator, denominator)
  16. return numerator // denominator
  17. def split_tensor_along_last_dim(
  18. tensor: torch.Tensor,
  19. num_partitions: int,
  20. contiguous_split_chunks: bool = False,
  21. ) -> Sequence[torch.Tensor]:
  22. """ Split a tensor along its last dimension.
  23. Arguments:
  24. tensor: input tensor.
  25. num_partitions: number of partitions to split the tensor
  26. contiguous_split_chunks: If True, make each chunk contiguous
  27. in memory.
  28. Returns:
  29. A list of Tensors
  30. """
  31. # Get the size and dimension.
  32. last_dim = tensor.dim() - 1
  33. last_dim_size = divide(tensor.size()[last_dim], num_partitions)
  34. # Split.
  35. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  36. # NOTE: torch.split does not create contiguous tensors by default.
  37. if contiguous_split_chunks:
  38. return tuple(chunk.contiguous() for chunk in tensor_list)
  39. return tensor_list
  40. def get_pp_indices(num_hidden_layers: int, pp_rank: int,
  41. pp_size: int) -> Tuple[int, int]:
  42. """Try to evenly distribute layers across partitions.
  43. If the number of layers is not divisible by the number of partitions,
  44. the last partition will have the remaining layers.
  45. """
  46. layers_per_partition = num_hidden_layers // pp_size
  47. start_layer = pp_rank * layers_per_partition
  48. end_layer = start_layer + layers_per_partition
  49. if pp_rank == pp_size - 1:
  50. end_layer = num_hidden_layers
  51. return (start_layer, end_layer)