# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist
from loguru import logger

from .parallel_state import get_cpu_world_group, get_local_rank


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the quotient."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
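

# Example (illustrative): partitioning a 4096-wide hidden dimension across
# 8 tensor-parallel ranks gives each rank a 512-wide shard, while an uneven
# split fails fast with an assertion:
#   divide(4096, 8)  # -> 512
#   divide(4096, 3)  # -> AssertionError: 4096 is not divisible by 3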


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: if True, make each chunk contiguous
            in memory.

    Returns:
        A sequence of tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
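

# Example (a minimal sketch): splitting the last dimension of a (2, 8)
# tensor into 4 partitions yields four (2, 2) views of the original storage:
#   x = torch.randn(2, 8)
#   chunks = split_tensor_along_last_dim(x, 4)
#   assert len(chunks) == 4 and chunks[0].shape == (2, 2)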


# Code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
    # Round-trip a small tensor between the two devices and verify that the
    # values survive: on some platforms the driver reports peer access even
    # though the actual copy silently corrupts data.
    dev_i = f"cuda:{idx_a}"
    dev_j = f"cuda:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()


# Why do we need this cache?
# 1. We can run the P2P access check at runtime, with every process probing
#    P2P access to all other GPUs. Unfortunately, that test can create up to
#    (world_size * world_size) CUDA contexts and reduce the memory available
#    for the model.
# 2. Alternatively, the master process can generate a P2P map and broadcast
#    it to all other processes. This still requires #world_size CUDA
#    contexts, belonging to the master process, on each GPU.
# 3. We can keep a cache file that records the P2P access status. The first
#    time the master process checks P2P access, it generates the cache file
#    at the cost of #world_size CUDA contexts. Afterwards, all processes can
#    read the cache file to check the P2P access status without creating any
#    additional CUDA context.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different aphrodite engines. The device id in the cache file
# is a **local** device id, i.e. from 0 to num_dev - 1, where num_dev is the
# number of visible devices in the aphrodite engine.
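# For illustration only (hypothetical contents), a cache file for two
# visible devices might look like:
#   {
#       "0->0": false,
#       "0->1": true,
#       "1->0": true,
#       "1->1": false
#   }
# (CUDA reports no *peer* access from a device to itself, so the diagonal
# entries come out false.)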
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j."""

    # If the cache variable is already populated, read from it instead of
    # checking again.
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.cuda.device_count()
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(d) for d in range(num_dev))
    path = os.path.expanduser(
        f"~/.config/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if (not is_distributed or get_local_rank() == 0) \
            and (not os.path.exists(path)):
        # Only the local master process (with local_rank == 0) can enter
        # this block to calculate the cache.
        logger.info(f"generating GPU P2P access cache in {path}")
        cache = {}
        for _i in range(num_dev):
            for _j in range(num_dev):
                # On some platforms, P2P support might be buggy and we need
                # additional checks.
                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
                    _i, _j) and _can_actually_p2p(_i, _j)
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info(f"reading GPU P2P access cache from {path}")
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]
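

# A minimal usage sketch (illustrative; assumes a single-process run with at
# least two visible CUDA devices):
#
#   if torch.cuda.device_count() >= 2 and gpu_p2p_access_check(0, 1):
#       ...  # safe to use direct peer-to-peer copies between GPU 0 and GPU 1
#
# The first call writes the JSON cache file under ~/.config/aphrodite/;
# subsequent calls (in any process with the same CUDA_VISIBLE_DEVICES
# setting) only read it.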