utils.py

from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                    Union, overload)

import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig

from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_pin_memory_available, progress_bar
from aphrodite.modeling.model_loader.loader import build_model
from aphrodite.modeling.models import ModelRegistry
from aphrodite.multimodal.base import NestedTensors
from aphrodite.quantization import QuantizationConfig


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    """
    Helper function to load weights for inner aphrodite models.

    See also:
        :ref:`init_aphrodite_registered_model`
    """
    weights_list = list(weights)
    for name, loaded_weight in progress_bar(weights_list,
                                            desc="Loading modules..."):
        name = name.split(".")
        if prefix == name.pop(0):
            name = ".".join(name)
            yield name, loaded_weight
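

# Illustrative usage sketch, added for clarity (not part of the original
# module): an outer multimodal model routing a mixed checkpoint to its inner
# language model.  The weight names and the "language_model" prefix are
# assumptions made for this example only.
def _example_filter_weights() -> None:
    weights = [
        ("language_model.embed_tokens.weight", torch.empty(4, 2)),
        ("vision_tower.patch_embed.weight", torch.empty(4, 2)),
    ]
    # Only names under the prefix survive, with the prefix stripped.
    filtered = list(filter_weights(weights, prefix="language_model"))
    assert [name for name, _ in filtered] == ["embed_tokens.weight"]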


def init_aphrodite_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to aphrodite,
    based on the arguments passed to the outer aphrodite model.
    """
    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)

    return build_model(
        model_class,
        hf_config,
        cache_config,
        quant_config,
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        scheduler_config=scheduler_config,
    )
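

# Hedged sketch (not part of the original module): an outer multimodal model
# typically constructs its inner text decoder by forwarding the configs it
# received itself.  ``config.text_config`` is an assumption; the sub-config
# attribute name depends on the architecture.
#
#     self.language_model = init_aphrodite_registered_model(
#         config.text_config, cache_config, quant_config)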


@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]
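

# Minimal sketch of the ``flatten_bn`` semantics, added for illustration (not
# part of the original module): a batch of B=2 samples with N=3 images each is
# flattened into 6 per-image tensors, either stacked or kept as a list.
def _example_flatten_bn() -> None:
    batched = torch.zeros(2, 3, 5)                    # (B, N, ...) -> (6, 5)
    assert flatten_bn(batched).shape == (2 * 3, 5)

    ragged = [torch.zeros(3, 5), torch.zeros(2, 5)]   # N may differ per sample
    assert len(flatten_bn(ragged)) == 3 + 2           # list of per-item rows
    assert flatten_bn(ragged, concat=True).shape == (5, 5)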


def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """
    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """
    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)


def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened
    return inputs_embeds
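

# Hedged sketch of the merge step, added for illustration (not part of the
# original module).  The placeholder token id (32000) and the tensor shapes
# are assumptions for this example.
def _example_merge_multimodal_embeddings() -> None:
    placeholder_token_id = 32000
    input_ids = torch.tensor([[1, 32000, 32000, 2]])
    inputs_embeds = torch.zeros(1, 4, 8)
    image_embeds = torch.ones(1, 2, 8)   # (num_images, num_patches, hidden)

    merged = merge_multimodal_embeddings(input_ids, inputs_embeds,
                                         image_embeds, placeholder_token_id)
    # The two placeholder positions now hold the image embeddings...
    assert torch.equal(merged[0, 1:3], torch.ones(2, 8))
    # ...and ``inputs_embeds`` itself was modified in place.
    assert merged is inputs_embeds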


class LayerFn(Protocol):

    def __call__(
        self,
        prefix="",
    ) -> torch.nn.Module:
        ...


class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module
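

# Illustrative sketch of the offloading knobs, added for clarity (not part of
# the original module).  With a zero budget ``maybe_offload_to_cpu`` is a
# no-op; with a sufficient budget, the parameters of a CUDA module are moved
# to (optionally pinned) CPU memory and streamed back to the device on the fly
# inside ``forward``.  The 1 MiB budget is an arbitrary value for the example;
# a CUDA device is required, hence the guard.
def _example_cpu_offload() -> None:
    if not torch.cuda.is_available():
        return

    layer = torch.nn.Linear(16, 16).cuda()
    set_cpu_offload_max_bytes(1 << 20)
    layer = maybe_offload_to_cpu(layer)

    # The parameters now live on the CPU, but forward still runs on CUDA.
    assert next(layer.parameters()).device == torch.device("cpu")
    out = layer(torch.randn(2, 16, device="cuda"))
    assert out.device.type == "cuda"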


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
    """
    from aphrodite.distributed.parallel_state import get_pp_group
    from aphrodite.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
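

# Hedged sketch of a typical call site, added for illustration (not part of
# the original module).  ``DecoderLayer`` is a stand-in for a real transformer
# block, and a running engine must have initialized the distributed state so
# that ``get_pp_group()`` works, which is why this is shown as a comment only.
#
#     self.start_layer, self.end_layer, self.layers = make_layers(
#         config.num_hidden_layers,
#         lambda prefix: DecoderLayer(config, cache_config, quant_config,
#                                     prefix=prefix),
#         prefix=f"{prefix}.layers",
#     )
#
# Indices outside ``[start_layer, end_layer)`` are filled with
# ``PPMissingLayer`` placeholders owned by other pipeline ranks.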


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model."""
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False
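

# Runnable illustration, added for clarity (not part of the original module):
# parameters under a ``PPMissingLayer`` belong to another pipeline rank, so
# weight loaders can use ``is_pp_missing_parameter`` to skip them.
def _example_is_pp_missing_parameter() -> None:
    model = torch.nn.Sequential(PPMissingLayer(), torch.nn.Linear(4, 4))
    assert is_pp_missing_parameter("0.weight", model)
    assert not is_pp_missing_parameter("1.weight", model)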


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
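

# Minimal usage sketch, added for illustration (not part of the original
# module): non-final pipeline ranks use the factory to allocate buffers for
# the tensors they receive from the previous rank.  The key names and hidden
# size below are assumptions for the example.
def _example_make_empty_intermediate_tensors() -> None:
    factory = make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size=4096)
    tensors = factory(batch_size=2, dtype=torch.float16,
                      device=torch.device("cpu"))
    assert tensors["hidden_states"].shape == (2, 4096)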