utils.py

# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist
from loguru import logger

from .parallel_state import get_cpu_world_group, get_local_rank


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
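
# Illustrative only, not part of the original module: a quick sketch of how
# divide() behaves. The numbers below are made up for the example; a hidden
# size of 4096 split across 8 tensor-parallel ranks gives 512 per rank, while
# a non-divisible split trips the assertion in ensure_divisibility().
#
#   >>> divide(4096, 8)
#   512
#   >>> divide(10, 3)
#   AssertionError: 10 is not divisible by 3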


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns:
        A list of Tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
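
# Illustrative only, not part of the original module: a minimal sketch of a
# 2-way split along the last dimension. The example tensor is made up; a 3x4
# input yields two contiguous 3x2 chunks.
#
#   >>> x = torch.arange(12).reshape(3, 4)
#   >>> chunks = split_tensor_along_last_dim(x, 2, contiguous_split_chunks=True)
#   >>> [c.shape for c in chunks]
#   [torch.Size([3, 2]), torch.Size([3, 2])]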


# Code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
    dev_i = f"cuda:{idx_a}"
    dev_j = f"cuda:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()


# Why do we need this cache?
# 1. We can run runtime checks for P2P access, where every process checks
#    P2P access to all other GPUs. Unfortunately, this test may create many
#    (world_size * world_size) CUDA contexts and reduce the memory available
#    for the model.
# 2. Alternatively, the master process can generate a P2P map and broadcast
#    it to all other processes. This still requires #world_size CUDA contexts,
#    belonging to the master process, on each GPU.
# 3. We can keep a cache file that records the P2P access status. The first
#    time the master process checks P2P access, it generates the cache file,
#    at the cost of #world_size CUDA contexts. Later on, all processes can
#    read the cache file to check the P2P access status without creating any
#    additional CUDA context.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different aphrodite engines. The device id in the cache file is
# a **local** device id, i.e. from 0 to num_dev - 1, where num_dev is the
# number of visible devices in the aphrodite engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
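
# Illustrative only, not part of the original module: the on-disk cache is a
# flat JSON object keyed by "src->dst" local device ids, as written by
# gpu_p2p_access_check() below. For a hypothetical two-GPU machine with
# working peer access, the file would look roughly like:
#
#   {
#       "0->0": true,
#       "0->1": true,
#       "1->0": true,
#       "1->1": true
#   }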


def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j."""
    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()
    num_dev = torch.cuda.device_count()
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
    path = os.path.expanduser(
        f"~/.config/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if (not is_distributed or get_local_rank() == 0) \
            and (not os.path.exists(path)):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info(f"generating GPU P2P access cache in {path}")
        cache = {}
        for _i in range(num_dev):
            for _j in range(num_dev):
                # on some platforms, P2P support might be buggy and we need
                # additional checks.
                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
                    _i, _j) and _can_actually_p2p(_i, _j)
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info(f"reading GPU P2P access cache from {path}")
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]
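
# Illustrative only, not part of the original module: a minimal sketch of how
# a caller might gate a device-to-device copy on gpu_p2p_access_check(). The
# tensor names and the fallback path are made up for the example.
#
#   if gpu_p2p_access_check(0, 1):
#       # peer access is available, so copy directly between the two GPUs
#       dst = torch.empty_like(src, device="cuda:0")
#       dst.copy_(src)
#   else:
#       # fall back to staging the tensor through host memory
#       dst = src.cpu().to("cuda:0")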