utils.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from typing import Dict, Iterable, List, Optional, Protocol, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from torch.func import functional_call
  5. from transformers import PretrainedConfig
  6. from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
  7. SchedulerConfig)
  8. from aphrodite.common.utils import is_pin_memory_available, progress_bar
  9. from aphrodite.modeling.model_loader.loader import build_model
  10. from aphrodite.modeling.models import ModelRegistry
  11. from aphrodite.multimodal import BatchedTensors
  12. from aphrodite.quantization import QuantizationConfig
  13. def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
  14. """
  15. Helper function to load weights for inner aphrodite models.
  16. See also:
  17. :ref:`init_aphrodite_registered_model`
  18. """
  19. weights_list = list(weights)
  20. for name, loaded_weight in progress_bar(weights_list,
  21. desc="Loading modules..."):
  22. name = name.split(".")
  23. if prefix == name.pop(0):
  24. name = ".".join(name)
  25. yield name, loaded_weight
  26. def init_aphrodite_registered_model(
  27. hf_config: PretrainedConfig,
  28. cache_config: Optional[CacheConfig],
  29. quant_config: Optional[QuantizationConfig],
  30. *,
  31. lora_config: Optional[LoRAConfig] = None,
  32. multimodal_config: Optional[MultiModalConfig] = None,
  33. scheduler_config: Optional[SchedulerConfig] = None,
  34. ) -> nn.Module:
  35. """
  36. Helper function to initialize an inner model registered to aphrodite,
  37. based on the arguments passed to the outer aphrodite model.
  38. """
  39. model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
  40. return build_model(
  41. model_class,
  42. hf_config,
  43. cache_config,
  44. quant_config,
  45. lora_config=lora_config,
  46. multimodal_config=multimodal_config,
  47. scheduler_config=scheduler_config,
  48. )
  49. def merge_multimodal_embeddings(input_ids: torch.Tensor,
  50. inputs_embeds: torch.Tensor,
  51. multimodal_embeddings: BatchedTensors,
  52. placeholder_token_id: int) -> torch.Tensor:
  53. """
  54. Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
  55. positions in ``inputs_embeds`` corresponding to placeholder tokens in
  56. ``input_ids``.
  57. Note:
  58. This updates ``inputs_embeds`` in place.
  59. """
  60. mask = (input_ids == placeholder_token_id)
  61. num_expected_tokens = mask.sum()
  62. if isinstance(multimodal_embeddings, torch.Tensor):
  63. batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
  64. total_tokens = batch_size * batch_tokens
  65. if num_expected_tokens != total_tokens:
  66. expr = f"{batch_size} x {batch_tokens}"
  67. raise ValueError(
  68. f"Attempted to assign {expr} = {total_tokens} "
  69. f"multimodal tokens to {num_expected_tokens} placeholders")
  70. inputs_embeds[mask] = multimodal_embeddings.view(
  71. total_tokens, embed_dim)
  72. else:
  73. size_per_batch = [t.shape[0] for t in multimodal_embeddings]
  74. total_tokens = sum(size_per_batch)
  75. if num_expected_tokens != total_tokens:
  76. expr = ' + '.join(map(str, size_per_batch))
  77. raise ValueError(
  78. f"Attempted to assign {expr} = {total_tokens} "
  79. f"multimodal tokens to {num_expected_tokens} placeholders")
  80. inputs_embeds[mask] = torch.cat(multimodal_embeddings)
  81. return inputs_embeds
  82. class LayerFn(Protocol):
  83. def __call__(
  84. self,
  85. prefix="",
  86. ) -> torch.nn.Module:
  87. ...
  88. class PPMissingLayer(torch.nn.Identity):
  89. """
  90. A placeholder layer for missing layers in a pipeline parallel model.
  91. """
  92. def __init__(self, *args, **kwargs):
  93. super().__init__()
  94. _CPU_OFFLOAD_BYTES = 0
  95. _CPU_OFFLOAD_MAX_BYTES = 0
  96. def set_cpu_offload_max_bytes(max_bytes: int) -> None:
  97. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  98. _CPU_OFFLOAD_BYTES = 0
  99. _CPU_OFFLOAD_MAX_BYTES = max_bytes
  100. def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
  101. device = next(module.parameters()).device
  102. if device == torch.device("cpu"):
  103. return module
  104. global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
  105. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  106. return module
  107. pin_memory = is_pin_memory_available()
  108. # offload parameters to CPU
  109. # use pin_memory if possible, which helps cudagraph capture speed
  110. offloaded_parameters = False
  111. for p in module.parameters():
  112. if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
  113. # we use per-parameter offloading
  114. # one module might have some parameters offloaded and some not
  115. break
  116. # `torch.empty_like` does not support `pin_memory` argument
  117. cpu_data = torch.empty_strided(size=p.data.size(),
  118. stride=p.data.stride(),
  119. dtype=p.data.dtype,
  120. layout=p.data.layout,
  121. device='cpu',
  122. pin_memory=pin_memory)
  123. cpu_data.copy_(p.data)
  124. p.data = cpu_data
  125. _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
  126. offloaded_parameters = True
  127. if offloaded_parameters:
  128. original_forward = module.forward
  129. def forward(*args, **kwargs):
  130. module.forward = original_forward
  131. device_state = {
  132. # here we blindly call `to(device)`
  133. # if the parameter is already on the device, it will be a no-op
  134. k: v.to(device, non_blocking=True)
  135. for k, v in module.state_dict().items()
  136. }
  137. output = functional_call(module,
  138. device_state,
  139. args=args,
  140. kwargs=kwargs)
  141. module.forward = forward
  142. return output
  143. module.forward = forward
  144. return module
  145. def make_layers(
  146. num_hidden_layers: int,
  147. layer_fn: LayerFn,
  148. prefix: str,
  149. ) -> Tuple[int, int, torch.nn.ModuleList]:
  150. """Make a list of layers with the given layer function, taking
  151. pipeline parallelism into account.
  152. """
  153. from aphrodite.distributed.parallel_state import get_pp_group
  154. from aphrodite.distributed.utils import get_pp_indices
  155. start_layer, end_layer = get_pp_indices(num_hidden_layers,
  156. get_pp_group().rank_in_group,
  157. get_pp_group().world_size)
  158. modules = torch.nn.ModuleList(
  159. [PPMissingLayer() for _ in range(start_layer)] + [
  160. maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
  161. for idx in range(start_layer, end_layer)
  162. ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
  163. return start_layer, end_layer, modules
  164. # NOTE: don't use lru_cache here because it can prevent garbage collection
  165. _model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
  166. def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
  167. """Get the names of the missing layers in a pipeline parallel model."""
  168. model_id = id(model)
  169. if model_id in _model_to_pp_missing_layer_names:
  170. return _model_to_pp_missing_layer_names[model_id]
  171. missing_layer_names = []
  172. for name, module in model.named_modules():
  173. if isinstance(module, PPMissingLayer):
  174. # NOTE: the trailing dot is used to match the prefix of the layer.
  175. # without the dot, we could match a layer that is not missing,
  176. # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
  177. missing_layer_names.append(name + '.')
  178. _model_to_pp_missing_layer_names[model_id] = missing_layer_names
  179. return missing_layer_names
  180. def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
  181. """Check if a parameter is missing in a pipeline parallel model."""
  182. for missing_layer_name in get_pp_missing_layer_names(model):
  183. if name.startswith(missing_layer_name):
  184. return True
  185. return False