from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                    Union, overload)

import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig

from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_pin_memory_available, progress_bar
from aphrodite.modeling.model_loader.loader import build_model
from aphrodite.modeling.models import ModelRegistry
from aphrodite.multimodal.base import NestedTensors
from aphrodite.quantization import QuantizationConfig


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    """
    Helper function to load weights for inner aphrodite models.

    See also:
        :ref:`init_aphrodite_registered_model`
    """
    weights_list = list(weights)
    for name, loaded_weight in progress_bar(weights_list,
                                            desc="Loading modules..."):
        name = name.split(".")
        if prefix == name.pop(0):
            name = ".".join(name)
            yield name, loaded_weight


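# Illustrative usage (not part of the original module): given checkpoint
# weights for a wrapper model whose language backbone was saved under
# "language_model.*", only that submodule's weights are yielded, with the
# prefix stripped so the inner model can load them directly:
#
#     weights = [("language_model.lm_head.weight", w1),
#                ("vision_tower.patch_embed.weight", w2)]
#     inner_weights = filter_weights(weights, "language_model")
#     # -> yields ("lm_head.weight", w1) only

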
def init_aphrodite_registered_model(
    hf_config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
    *,
    lora_config: Optional[LoRAConfig] = None,
    multimodal_config: Optional[MultiModalConfig] = None,
    scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to aphrodite,
    based on the arguments passed to the outer aphrodite model.
    """
    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)

    return build_model(
        model_class,
        hf_config,
        cache_config,
        quant_config,
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        scheduler_config=scheduler_config,
    )


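# Illustrative call site (assumed attribute and config names, not from this
# file): a multimodal wrapper model would typically build its text backbone
# from the text sub-config of its composite HF config, e.g.
#
#     self.language_model = init_aphrodite_registered_model(
#         config.text_config, cache_config, quant_config)

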
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]


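# Behaviour sketch (illustrative):
#
#     x = torch.zeros(2, 3, 4)            # a single (B, N, ...) tensor
#     flatten_bn(x).shape                 # -> torch.Size([6, 4])
#
#     xs = [torch.zeros(3, 4), torch.zeros(2, 4)]   # varying N per sample
#     len(flatten_bn(xs))                 # -> 5 (still a list of tensors)
#     flatten_bn(xs, concat=True).shape   # -> torch.Size([5, 4])

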
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """
    if isinstance(embeddings, torch.Tensor):
        # Flatten all but the last dimension.
        return embeddings.flatten(0, -2)

    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """
    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)


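# Illustrative behaviour of the two helpers above: for a nested structure of
# per-image feature tensors such as
#
#     embeds = [torch.zeros(2, 576, 1024), torch.zeros(300, 1024)]
#
# ``_flatten_embeddings(embeds)`` concatenates them into a single
# (2 * 576 + 300, 1024) tensor, while ``_embedding_count_expression(embeds)``
# renders the string "2 x 576 + 300" for use in error messages.

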
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened
    return inputs_embeds


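# Illustrative usage (names such as ``image_token_id``, ``vision_embeddings``
# and the embedding accessor are assumptions, not from this file): overwrite
# the placeholder positions of the text embeddings with the vision encoder's
# outputs:
#
#     inputs_embeds = self.language_model.get_input_embeddings(input_ids)
#     inputs_embeds = merge_multimodal_embeddings(
#         input_ids, inputs_embeds, vision_embeddings,
#         placeholder_token_id=image_token_id)

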
class LayerFn(Protocol):

    def __call__(
        self,
        prefix: str = "",
    ) -> torch.nn.Module:
        ...


class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        # Accept (and ignore) arbitrary arguments so this placeholder can be
        # constructed with the same signature as the layer it stands in for.
        super().__init__()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # restore the original forward so that `functional_call` below
            # does not recurse back into this wrapper
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module


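# Usage sketch (illustrative): the worker sets a global offload budget once,
# then wraps each layer as it is built; offloaded parameters are copied back
# to the original device on the fly inside the wrapped ``forward``:
#
#     set_cpu_offload_max_bytes(4 * 1024**3)   # allow up to ~4 GiB on CPU
#     layer = maybe_offload_to_cpu(layer)

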
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
    """
    from aphrodite.distributed.parallel_state import get_pp_group
    from aphrodite.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
            for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules


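# Illustrative usage (``DecoderLayer`` and the config fields are stand-ins,
# not from this file): a pipeline-parallel-aware decoder stack would be
# built as
#
#     self.start_layer, self.end_layer, self.layers = make_layers(
#         config.num_hidden_layers,
#         lambda prefix: DecoderLayer(config, cache_config, quant_config,
#                                     prefix=prefix),
#         prefix="model.layers")

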
# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model."""
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]

    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            # NOTE: the trailing dot is used to match the prefix of the layer.
            # without the dot, we could match a layer that is not missing,
            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
            missing_layer_names.append(name + '.')
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False


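# Usage sketch (illustrative): weight loaders can skip parameters that belong
# to layers owned by other pipeline stages:
#
#     for name, loaded_weight in weights:
#         if is_pp_missing_parameter(name, model):
#             continue
#         # ... load the weight as usual

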
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
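

# Illustrative usage (assumed attribute and config names): models typically
# expose the factory result so the engine can allocate placeholder tensors
# for all but the first pipeline stage:
#
#     self.make_empty_intermediate_tensors = (
#         make_empty_intermediate_tensors_factory(
#             ["hidden_states", "residual"], config.hidden_size))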