utils.py

from typing import Dict, List, Protocol, Tuple

import torch
from torch.func import functional_call

from aphrodite.common.utils import is_pin_memory_available
from aphrodite.multimodal import BatchedTensors


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: BatchedTensors,
                            image_token_id: int) -> torch.Tensor:
    """
    Merge `vision_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` corresponding to placeholder image tokens
    in `input_ids`.

    Note:
        This updates `inputs_embeds` in place.
    """
    mask = (input_ids == image_token_id)
    num_expected_tokens = mask.sum()

    if isinstance(vision_embeddings, torch.Tensor):
        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
        total_tokens = batch_size * batch_tokens
        if num_expected_tokens != total_tokens:
            expr = f"{batch_size} x {batch_tokens}"
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
    else:
        size_per_batch = [t.shape[0] for t in vision_embeddings]
        total_tokens = sum(size_per_batch)
        if num_expected_tokens != total_tokens:
            expr = ' + '.join(map(str, size_per_batch))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = torch.cat(vision_embeddings)

    return inputs_embeds
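

# Usage sketch (illustrative only, not part of this module's API): how a
# multimodal model might splice image patch embeddings into its text
# embeddings. The shapes, values, and image token id below are assumptions
# chosen for demonstration.
def _example_merge_vision_embeddings() -> torch.Tensor:
    image_token_id = 32000  # hypothetical placeholder token id
    # Four positions, two of which are image placeholders.
    input_ids = torch.tensor([1, image_token_id, image_token_id, 2])
    inputs_embeds = torch.zeros(4, 8)
    # A single image yielding two patch embeddings of dim 8.
    vision_embeddings = torch.randn(1, 2, 8)
    return merge_vision_embeddings(input_ids, inputs_embeds,
                                   vision_embeddings, image_token_id)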


class LayerFn(Protocol):

    def __call__(
        self,
        prefix: str = "",
    ) -> torch.nn.Module:
        ...


class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty(size=p.data.size(),
                               dtype=p.data.dtype,
                               layout=p.data.layout,
                               device='cpu',
                               pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    state_dict: Dict[str, torch.Tensor] = module.state_dict()
    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward
        device_state = {
            # here we blindly call `to(device)`
            # if the parameter is already on the device, it will be a no-op
            k: v.to(device, non_blocking=True)
            for k, v in state_dict.items()
        }
        output = functional_call(module,
                                 device_state,
                                 args=args,
                                 kwargs=kwargs)
        module.forward = forward
        return output

    module.forward = forward

    return module
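

# Usage sketch (illustrative only): capping the CPU offload budget and
# offloading a layer's weights. The 1 GiB budget and the Linear layer are
# arbitrary examples, and this assumes a CUDA device is available.
def _example_cpu_offload() -> torch.nn.Module:
    set_cpu_offload_max_bytes(1 << 30)  # allow up to 1 GiB of weights on CPU
    layer = torch.nn.Linear(4096, 4096).to("cuda")
    # Parameters now live in (pinned) CPU memory; the wrapped `forward`
    # streams them back to the GPU for the duration of each call.
    return maybe_offload_to_cpu(layer)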


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
    """
    from aphrodite.distributed.parallel_state import get_pp_group
    from aphrodite.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
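

# Usage sketch (illustrative only): building a decoder stack with
# `make_layers`. The stand-in layer and layer count are hypothetical, and
# this assumes the pipeline-parallel process group has already been
# initialized (otherwise `get_pp_group()` will fail).
def _example_make_layers() -> torch.nn.ModuleList:

    def layer_fn(prefix: str = "") -> torch.nn.Module:
        return torch.nn.Linear(16, 16)  # stand-in for a transformer block

    start_layer, end_layer, layers = make_layers(num_hidden_layers=4,
                                                 layer_fn=layer_fn,
                                                 prefix="model.layers")
    # Indices outside [start_layer, end_layer) hold PPMissingLayer
    # placeholders owned by other pipeline-parallel ranks.
    return layers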


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model."""
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False
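

# Usage sketch (illustrative only): a weight loader skipping checkpoint
# tensors that belong to layers another pipeline-parallel rank owns. The
# `weights` iterable of (name, tensor) pairs and the copy-based loading step
# are assumptions about the surrounding code, not this module's API.
def _example_skip_pp_missing_weights(model: torch.nn.Module, weights) -> None:
    params = dict(model.named_parameters())
    for name, loaded_weight in weights:
        if is_pp_missing_parameter(name, model):
            continue  # this rank holds a PPMissingLayer for this prefix
        params[name].data.copy_(loaded_weight)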