utils.py

import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                    Union, overload)

import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig

from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_pin_memory_available, progress_bar
from aphrodite.modeling.model_loader.loader import build_model
from aphrodite.modeling.models import ModelRegistry
from aphrodite.multimodal.base import NestedTensors
from aphrodite.quantization import QuantizationConfig


class WeightsGroup(UserDict):
    """
    Wraps a grouped weights dictionary to give a more informative error
    message when attempting to access a weight component that does not exist.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There are no weights with the prefix: {key}. "
                   f"Available prefixes: {set(self.keys())}")
            raise KeyError(msg) from exc


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
    """
    Helper function to load weights for inner aphrodite models.

    See also:
        :ref:`init_aphrodite_registered_model`
    """
    weights_list = list(weights)
    for name, loaded_weight in progress_bar(weights_list,
                                            desc="Loading modules..."):
        name = name.split(".")
        if prefix == name.pop(0):
            name = ".".join(name)
            yield name, loaded_weight
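
# Illustrative sketch: `filter_weights` strips a matching top-level prefix and
# drops everything else. The weight names below are hypothetical.
#
#     >>> weights = [("vision_tower.proj", torch.zeros(4)),
#     ...            ("language_model.embed", torch.zeros(4))]
#     >>> [name for name, _ in filter_weights(weights, "vision_tower")]
#     ['proj']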


def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
    """
    Helper function to group weights by the first component of their name.
    """
    init_weights, repeated_weights = itertools.tee(weights, 2)
    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))

    return WeightsGroup({
        prefix: filter_weights(component, prefix)
        for component, prefix in zip(repeated_weights, weights_prefix)
    })
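
# Illustrative sketch: grouping by top-level prefix yields one lazy weight
# stream per sub-model. The prefixes below ("vision_tower", "language_model")
# are hypothetical and depend on the outer model's attribute names.
#
#     >>> grouped = group_weights_with_prefix([
#     ...     ("vision_tower.proj", torch.zeros(1)),
#     ...     ("language_model.embed", torch.zeros(1)),
#     ... ])
#     >>> sorted(grouped.keys())
#     ['language_model', 'vision_tower']
#     >>> grouped["missing"]   # raises KeyError listing the available prefixes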


def init_aphrodite_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to aphrodite,
    based on the arguments passed to the outer aphrodite model.
    """
    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)

    return build_model(
        model_class,
        hf_config,
        cache_config,
        quant_config,
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        scheduler_config=scheduler_config,
    )
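
# Illustrative sketch: a composite (e.g. multimodal) outer model would
# typically initialize its inner language model like this in `__init__`.
# `config.text_config` is an assumption about the HF config layout and varies
# by model family.
#
#     self.language_model = init_aphrodite_registered_model(
#         config.text_config, cache_config, quant_config)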


@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]
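
# Illustrative sketch of the three behaviours, with hypothetical shapes:
#
#     >>> x = torch.zeros(2, 3, 5)                      # (B=2, N=3, ...)
#     >>> flatten_bn(x).shape
#     torch.Size([6, 5])
#     >>> xs = [torch.zeros(3, 5), torch.zeros(2, 5)]   # ragged N per batch
#     >>> len(flatten_bn(xs))                           # list in, list out
#     5
#     >>> flatten_bn(xs, concat=True).shape             # list in, tensor out
#     torch.Size([5, 5])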


def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """
    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """
    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)
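
# Illustrative sketch: for a nested structure of two image feature tensors
# with hypothetical shapes (2, 16, 768) and (1, 16, 768),
# `_flatten_embeddings` returns a (48, 768) tensor, and
# `_embedding_count_expression` renders the matching count as
# "2 x 16 + 1 x 16" for use in error messages.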


def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened
    return inputs_embeds
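
# Illustrative sketch: scatter image embeddings into the token embedding
# sequence at the placeholder positions. The placeholder id 32000 is
# hypothetical and model-specific.
#
#     >>> input_ids = torch.tensor([1, 32000, 32000, 2])
#     >>> inputs_embeds = torch.zeros(4, 8)
#     >>> image_embeds = torch.ones(2, 8)
#     >>> out = merge_multimodal_embeddings(input_ids, inputs_embeds,
#     ...                                   image_embeds,
#     ...                                   placeholder_token_id=32000)
#     >>> out[1].sum().item(), out[0].sum().item()
#     (8.0, 0.0)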


class LayerFn(Protocol):

    def __call__(
        self,
        prefix="",
    ) -> torch.nn.Module:
        ...


class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module
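
# Illustrative sketch: set a 4 GiB CPU-offload budget before building the
# model, then wrap each layer. Offloaded parameters live in (optionally
# pinned) CPU memory and are streamed back to the original device on every
# forward pass via `functional_call`.
#
#     set_cpu_offload_max_bytes(4 * 1024 * 1024 * 1024)
#     layer = maybe_offload_to_cpu(layer)   # typically done inside make_layers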


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
    """
    from aphrodite.distributed.parallel_state import get_pp_group
    from aphrodite.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
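
# Illustrative sketch: a decoder-only model would build its layer stack like
# this; layers outside this pipeline rank's range become `PPMissingLayer`.
# `DecoderLayer` and the prefix string are hypothetical.
#
#     self.start_layer, self.end_layer, self.layers = make_layers(
#         config.num_hidden_layers,
#         lambda prefix: DecoderLayer(config, cache_config, quant_config,
#                                     prefix=prefix),
#         prefix="model.layers")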


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model."""
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False
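
# Illustrative sketch: a weight loader would typically skip parameters that
# belong to layers owned by another pipeline rank. The loop variables are
# hypothetical.
#
#     for name, loaded_weight in weights:
#         if is_pp_missing_parameter(name, self):
#             continue
#         # ... load `loaded_weight` into this rank's parameter `name`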


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
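
# Illustrative sketch: a pipeline-parallel model can expose the factory so the
# runner can allocate zero-filled tensors for ranks that receive activations
# from a previous stage. The key names below are assumptions and depend on
# what the model passes between stages.
#
#     self.make_empty_intermediate_tensors = (
#         make_empty_intermediate_tensors_factory(
#             ["hidden_states", "residual"], config.hidden_size))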