utils.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
  2. Union, overload)
  3. import torch
  4. import torch.nn as nn
  5. from torch.func import functional_call
  6. from transformers import PretrainedConfig
  7. from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
  8. SchedulerConfig)
  9. from aphrodite.common.utils import is_pin_memory_available, progress_bar
  10. from aphrodite.modeling.model_loader.loader import build_model
  11. from aphrodite.modeling.models import ModelRegistry
  12. from aphrodite.multimodal.base import NestedTensors
  13. from aphrodite.quantization import QuantizationConfig
  14. def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
  15. """
  16. Helper function to load weights for inner aphrodite models.
  17. See also:
  18. :ref:`init_aphrodite_registered_model`
  19. """
  20. weights_list = list(weights)
  21. for name, loaded_weight in progress_bar(weights_list,
  22. desc="Loading modules..."):
  23. name = name.split(".")
  24. if prefix == name.pop(0):
  25. name = ".".join(name)
  26. yield name, loaded_weight
  27. def init_aphrodite_registered_model(
  28. hf_config: PretrainedConfig,
  29. cache_config: Optional[CacheConfig],
  30. quant_config: Optional[QuantizationConfig],
  31. *,
  32. lora_config: Optional[LoRAConfig] = None,
  33. multimodal_config: Optional[MultiModalConfig] = None,
  34. scheduler_config: Optional[SchedulerConfig] = None,
  35. ) -> nn.Module:
  36. """
  37. Helper function to initialize an inner model registered to aphrodite,
  38. based on the arguments passed to the outer aphrodite model.
  39. """
  40. model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
  41. return build_model(
  42. model_class,
  43. hf_config,
  44. cache_config,
  45. quant_config,
  46. lora_config=lora_config,
  47. multimodal_config=multimodal_config,
  48. scheduler_config=scheduler_config,
  49. )
  50. @overload
  51. def flatten_bn(x: torch.Tensor) -> torch.Tensor:
  52. ...
  53. @overload
  54. def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
  55. ...
  56. @overload
  57. def flatten_bn(
  58. x: Union[List[torch.Tensor], torch.Tensor],
  59. *,
  60. concat: Literal[True],
  61. ) -> torch.Tensor:
  62. ...
  63. def flatten_bn(
  64. x: Union[List[torch.Tensor], torch.Tensor],
  65. *,
  66. concat: bool = False,
  67. ) -> Union[List[torch.Tensor], torch.Tensor]:
  68. """
  69. Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
  70. The input tensor should have shape ``(B, N, ...)```.
  71. """
  72. if isinstance(x, torch.Tensor):
  73. return x.flatten(0, 1)
  74. if concat:
  75. return torch.cat(x)
  76. return [x_n for x_b in x for x_n in x_b]
  77. def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
  78. """
  79. Recursively flattens and concatenates NestedTensors on all but the last
  80. dimension.
  81. """
  82. if isinstance(embeddings, torch.Tensor):
  83. # Flatten all but the last dimension.
  84. return embeddings.flatten(0, -2)
  85. return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
  86. def _embedding_count_expression(embeddings: NestedTensors) -> str:
  87. """
  88. Constructs a debugging representation of the number of embeddings in the
  89. NestedTensors.
  90. """
  91. if isinstance(embeddings, torch.Tensor):
  92. return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
  93. return " + ".join(
  94. _embedding_count_expression(inner) for inner in embeddings)
  95. def merge_multimodal_embeddings(input_ids: torch.Tensor,
  96. inputs_embeds: torch.Tensor,
  97. multimodal_embeddings: NestedTensors,
  98. placeholder_token_id: int) -> torch.Tensor:
  99. """
  100. Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
  101. positions in ``inputs_embeds`` corresponding to placeholder tokens in
  102. ``input_ids``.
  103. Note:
  104. This updates ``inputs_embeds`` in place.
  105. """
  106. mask = (input_ids == placeholder_token_id)
  107. num_expected_tokens = mask.sum().item()
  108. assert isinstance(num_expected_tokens, int)
  109. flattened = _flatten_embeddings(multimodal_embeddings)
  110. if flattened.shape[0] != num_expected_tokens:
  111. expr = _embedding_count_expression(multimodal_embeddings)
  112. raise ValueError(
  113. f"Attempted to assign {expr} = {flattened.shape[0]} "
  114. f"multimodal tokens to {num_expected_tokens} placeholders")
  115. inputs_embeds[mask] = flattened
  116. return inputs_embeds
  117. class LayerFn(Protocol):
  118. def __call__(
  119. self,
  120. prefix="",
  121. ) -> torch.nn.Module:
  122. ...
  123. class PPMissingLayer(torch.nn.Identity):
  124. """
  125. A placeholder layer for missing layers in a pipeline parallel model.
  126. """
  127. def __init__(self, *args, **kwargs):
  128. super().__init__()
  129. _CPU_OFFLOAD_BYTES = 0
  130. _CPU_OFFLOAD_MAX_BYTES = 0
  131. def set_cpu_offload_max_bytes(max_bytes: int) -> None:
  132. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  133. _CPU_OFFLOAD_BYTES = 0
  134. _CPU_OFFLOAD_MAX_BYTES = max_bytes
  135. def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
  136. device = next(module.parameters()).device
  137. if device == torch.device("cpu"):
  138. return module
  139. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  140. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  141. return module
  142. pin_memory = is_pin_memory_available()
  143. # offload parameters to CPU
  144. # use pin_memory if possible, which helps cudagraph capture speed
  145. offloaded_parameters = False
  146. for p in module.parameters():
  147. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  148. # we use per-parameter offloading
  149. # one module might have some parameters offloaded and some not
  150. break
  151. # `torch.empty_like` does not support `pin_memory` argument
  152. cpu_data = torch.empty_strided(size=p.data.size(),
  153. stride=p.data.stride(),
  154. dtype=p.data.dtype,
  155. layout=p.data.layout,
  156. device='cpu',
  157. pin_memory=pin_memory)
  158. cpu_data.copy_(p.data)
  159. p.data = cpu_data
  160. _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
  161. offloaded_parameters = True
  162. if offloaded_parameters:
  163. original_forward = module.forward
  164. def forward(*args, **kwargs):
  165. module.forward = original_forward
  166. device_state = {
  167. # here we blindly call `to(device)`
  168. # if the parameter is already on the device, it will be a no-op
  169. k: v.to(device, non_blocking=True)
  170. for k, v in module.state_dict().items()
  171. }
  172. output = functional_call(module,
  173. device_state,
  174. args=args,
  175. kwargs=kwargs)
  176. module.forward = forward
  177. return output
  178. module.forward = forward
  179. return module
  180. def make_layers(
  181. num_hidden_layers: int,
  182. layer_fn: LayerFn,
  183. prefix: str,
  184. ) -> Tuple[int, int, torch.nn.ModuleList]:
  185. """Make a list of layers with the given layer function, taking
  186. pipeline parallelism into account.
  187. """
  188. from aphrodite.distributed.parallel_state import get_pp_group
  189. from aphrodite.distributed.utils import get_pp_indices
  190. start_layer, end_layer = get_pp_indices(num_hidden_layers,
  191. get_pp_group().rank_in_group,
  192. get_pp_group().world_size)
  193. modules = torch.nn.ModuleList(
  194. [PPMissingLayer() for _ in range(start_layer)] + [
  195. maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
  196. for idx in range(start_layer, end_layer)
  197. ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
  198. return start_layer, end_layer, modules
  199. # NOTE: don't use lru_cache here because it can prevent garbage collection
  200. _model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
  201. def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
  202. """Get the names of the missing layers in a pipeline parallel model."""
  203. model_id = id(model)
  204. if model_id in _model_to_pp_missing_layer_names:
  205. return _model_to_pp_missing_layer_names[model_id]
  206. missing_layer_names = []
  207. for name, module in model.named_modules():
  208. if isinstance(module, PPMissingLayer):
  209. # NOTE: the trailing dot is used to match the prefix of the layer.
  210. # without the dot, we could match a layer that is not missing,
  211. # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
  212. missing_layer_names.append(name + '.')
  213. _model_to_pp_missing_layer_names[model_id] = missing_layer_names
  214. return missing_layer_names
  215. def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
  216. """Check if a parameter is missing in a pipeline parallel model."""
  217. for missing_layer_name in get_pp_missing_layer_names(model):
  218. if name.startswith(missing_layer_name):
  219. return True
  220. return False