utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from typing import Dict, Iterable, List, Optional, Protocol, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from torch.func import functional_call
  5. from transformers import PretrainedConfig
  6. from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
  7. SchedulerConfig)
  8. from aphrodite.common.utils import is_pin_memory_available
  9. from aphrodite.modeling.model_loader.loader import build_model
  10. from aphrodite.modeling.models import ModelRegistry
  11. from aphrodite.multimodal import BatchedTensors
  12. from aphrodite.quantization import QuantizationConfig
  13. def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
  14. """
  15. Helper function to load weights for inner aphrodite models.
  16. See also:
  17. :ref:`init_aphrodite_registered_model`
  18. """
  19. for name, loaded_weight in weights:
  20. name = name.split(".")
  21. if prefix == name.pop(0):
  22. name = ".".join(name)
  23. yield name, loaded_weight
  24. def init_aphrodite_registered_model(
  25. hf_config: PretrainedConfig,
  26. cache_config: Optional[CacheConfig],
  27. quant_config: Optional[QuantizationConfig],
  28. *,
  29. lora_config: Optional[LoRAConfig] = None,
  30. multimodal_config: Optional[MultiModalConfig] = None,
  31. scheduler_config: Optional[SchedulerConfig] = None,
  32. ) -> nn.Module:
  33. """
  34. Helper function to initialize an inner model registered to aphrodite,
  35. based on the arguments passed to the outer aphrodite model.
  36. """
  37. model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
  38. return build_model(
  39. model_class,
  40. hf_config,
  41. cache_config,
  42. quant_config,
  43. lora_config=lora_config,
  44. multimodal_config=multimodal_config,
  45. scheduler_config=scheduler_config,
  46. )
  47. def merge_vision_embeddings(input_ids: torch.Tensor,
  48. inputs_embeds: torch.Tensor,
  49. vision_embeddings: BatchedTensors,
  50. image_token_id: int) -> torch.Tensor:
  51. """
  52. Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
  53. positions in ``inputs_embeds`` corresponding to placeholder image tokens in
  54. ``input_ids``.
  55. Note:
  56. This updates ``inputs_embeds`` in place.
  57. """
  58. mask = (input_ids == image_token_id)
  59. num_expected_tokens = mask.sum()
  60. if isinstance(vision_embeddings, torch.Tensor):
  61. batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
  62. total_tokens = batch_size * batch_tokens
  63. if num_expected_tokens != total_tokens:
  64. expr = f"{batch_size} x {batch_tokens}"
  65. raise ValueError(
  66. f"Attempted to assign {expr} = {total_tokens} "
  67. f"image tokens to {num_expected_tokens} placeholders")
  68. inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
  69. else:
  70. size_per_batch = [t.shape[0] for t in vision_embeddings]
  71. total_tokens = sum(size_per_batch)
  72. if num_expected_tokens != total_tokens:
  73. expr = ' + '.join(map(str, size_per_batch))
  74. raise ValueError(
  75. f"Attempted to assign {expr} = {total_tokens} "
  76. f"image tokens to {num_expected_tokens} placeholders")
  77. inputs_embeds[mask] = torch.cat(vision_embeddings)
  78. return inputs_embeds
  79. class LayerFn(Protocol):
  80. def __call__(
  81. self,
  82. prefix="",
  83. ) -> torch.nn.Module:
  84. ...
  85. class PPMissingLayer(torch.nn.Identity):
  86. """
  87. A placeholder layer for missing layers in a pipeline parallel model.
  88. """
  89. def __init__(self, *args, **kwargs):
  90. super().__init__()
  91. _CPU_OFFLOAD_BYTES = 0
  92. _CPU_OFFLOAD_MAX_BYTES = 0
  93. def set_cpu_offload_max_bytes(max_bytes: int) -> None:
  94. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  95. _CPU_OFFLOAD_BYTES = 0
  96. _CPU_OFFLOAD_MAX_BYTES = max_bytes
  97. def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
  98. device = next(module.parameters()).device
  99. if device == torch.device("cpu"):
  100. return module
  101. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  102. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  103. return module
  104. pin_memory = is_pin_memory_available()
  105. # offload parameters to CPU
  106. # use pin_memory if possible, which helps cudagraph capture speed
  107. offloaded_parameters = False
  108. for p in module.parameters():
  109. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  110. # we use per-parameter offloading
  111. # one module might have some parameters offloaded and some not
  112. break
  113. # `torch.empty_like` does not support `pin_memory` argument
  114. cpu_data = torch.empty_strided(size=p.data.size(),
  115. stride=p.data.stride(),
  116. dtype=p.data.dtype,
  117. layout=p.data.layout,
  118. device='cpu',
  119. pin_memory=pin_memory)
  120. cpu_data.copy_(p.data)
  121. p.data = cpu_data
  122. _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
  123. offloaded_parameters = True
  124. if offloaded_parameters:
  125. original_forward = module.forward
  126. def forward(*args, **kwargs):
  127. module.forward = original_forward
  128. device_state = {
  129. # here we blindly call `to(device)`
  130. # if the parameter is already on the device, it will be a no-op
  131. k: v.to(device, non_blocking=True)
  132. for k, v in module.state_dict().items()
  133. }
  134. output = functional_call(module,
  135. device_state,
  136. args=args,
  137. kwargs=kwargs)
  138. module.forward = forward
  139. return output
  140. module.forward = forward
  141. return module
  142. def make_layers(
  143. num_hidden_layers: int,
  144. layer_fn: LayerFn,
  145. prefix: str,
  146. ) -> Tuple[int, int, torch.nn.ModuleList]:
  147. """Make a list of layers with the given layer function, taking
  148. pipeline parallelism into account.
  149. """
  150. from aphrodite.distributed.parallel_state import get_pp_group
  151. from aphrodite.distributed.utils import get_pp_indices
  152. start_layer, end_layer = get_pp_indices(num_hidden_layers,
  153. get_pp_group().rank_in_group,
  154. get_pp_group().world_size)
  155. modules = torch.nn.ModuleList(
  156. [PPMissingLayer() for _ in range(start_layer)] + [
  157. maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
  158. for idx in range(start_layer, end_layer)
  159. ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
  160. return start_layer, end_layer, modules
  161. # NOTE: don't use lru_cache here because it can prevent garbage collection
  162. _model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
  163. def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
  164. """Get the names of the missing layers in a pipeline parallel model."""
  165. model_id = id(model)
  166. if model_id in _model_to_pp_missing_layer_names:
  167. return _model_to_pp_missing_layer_names[model_id]
  168. missing_layer_names = []
  169. for name, module in model.named_modules():
  170. if isinstance(module, PPMissingLayer):
  171. # NOTE: the trailing dot is used to match the prefix of the layer.
  172. # without the dot, we could match a layer that is not missing,
  173. # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
  174. missing_layer_names.append(name + '.')
  175. _model_to_pp_missing_layer_names[model_id] = missing_layer_names
  176. return missing_layer_names
  177. def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
  178. """Check if a parameter is missing in a pipeline parallel model."""
  179. for missing_layer_name in get_pp_missing_layer_names(model):
  180. if name.startswith(missing_layer_name):
  181. return True
  182. return False