123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Any, Dict, List, Tuple, Union
- import torch
- from torch import nn
- def split_decoder_layer_inputs(
- *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any]
- ) -> Tuple[List[List[Any]], List[Dict[str, Any]]]:
- """This function splits batched decoder layer inputs into individual
- elements.
- Args:
- *args (Union[torch.Tensor, Any]): Positional arguments which could
- be a mix of tensors and other types.
- **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could
- be a mix of tensors and other types.
- Returns:
- Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two
- lists, one for positional arguments, one for keyword arguments.
- Each list contains individual elements from the batch.
- """
- if not isinstance(args[0], torch.Tensor):
- raise ValueError('The first argument must be a Tensor')
- bs = args[0].size(0)
- batch_args = []
- batch_kwargs = []
- for i in range(bs):
- new_args = []
- # Iterate over each argument. If it's a torch.Tensor and its first
- # dimension equals the batch size, then get the value corresponding
- # to the current index, else directly add the whole value.
- for val in args:
- if isinstance(val, torch.Tensor) and val.size(0) == bs:
- new_args.append(val[i:i + 1])
- else:
- new_args.append(val)
- new_kwargs = {}
- # Execute the same operation for the keyword arguments.
- for name, val in kwargs.items():
- if isinstance(val, torch.Tensor) and val.size(0) == bs:
- new_kwargs[name] = val[i:i + 1]
- else:
- new_kwargs[name] = val
- batch_args.append(new_args)
- batch_kwargs.append(new_kwargs)
- return batch_args, batch_kwargs
- def concat_decoder_layer_outputs(
- batch_outputs: List[Tuple[Any]]) -> Tuple[Any]:
- """This function concatenates individual decoder layer outputs into a
- batched output.
- Args:
- batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple
- represents the output from an individual element in the batch.
- Returns:
- Tuple[Any]: A tuple representing the batched output.
- """
- num_returns = len(batch_outputs[0])
- def is_past_key_value(data: Any) -> bool:
- """Check whether data is a past key-value pair.
- Args:
- data (Any): The data to check.
- Returns:
- bool: True if data is a past key-value pair, False otherwise.
- """
- flag = isinstance(data, tuple)
- flag = flag and len(data) == 2
- flag = flag and isinstance(data[0], torch.Tensor)
- flag = flag and isinstance(data[1], torch.Tensor)
- return flag
- new_outputs = []
- # Iterate over all types of return values.
- for i in range(num_returns):
- # Check if the current element is a past key-value pair.
- flag = is_past_key_value(batch_outputs[0][i])
- if flag:
- # Concatenate the keys and values separately.
- key = torch.cat([out[i][0] for out in batch_outputs])
- value = torch.cat([out[i][1] for out in batch_outputs])
- out_i = (key, value)
- else:
- # If it's not a past key-value pair, concatenate directly.
- out_i = torch.cat([out[i] for out in batch_outputs])
- new_outputs.append(out_i)
- return tuple(new_outputs)
- def collect_target_modules(
- model: nn.Module,
- # target: Union[str, type],
- target: str,
- skip_names: List[str] = None,
- prefix: str = '') -> Dict[str, nn.Module]:
- """Collects the specific target modules from the model.
- Args:
- model : The PyTorch module from which to collect the target modules.
- target : The specific target to be collected. It can be a class of a
- module or the name of a module.
- skip_names : List of names of modules to be skipped during collection.
- prefix : A string to be added as a prefix to the module names.
- Returns:
- A dictionary mapping from module names to module instances.
- """
- # if isinstance(target, LazyAttr):
- # target = target.build()
- if skip_names is None:
- skip_names = []
- if not isinstance(target, (type, str)):
- raise TypeError('Target must be a string (name of the module) '
- 'or a type (class of the module)')
- def _is_target(n, m):
- if isinstance(target, str):
- return target == type(m).__name__ and n not in skip_names
- return isinstance(m, target) and n not in skip_names
- name2mod = {}
- for name, mod in model.named_modules():
- m_name = f'{prefix}.{name}' if prefix else name
- if _is_target(name, mod):
- name2mod[m_name] = mod
- return name2mod
- def bimap_name_mod(
- name2mod_mappings: List[Dict[str, nn.Module]]
- ) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]:
- """Generates bidirectional maps from module names to module instances and
- vice versa.
- Args:
- name2mod_mappings : List of dictionaries each mapping from module
- names to module instances.
- Returns:
- Two dictionaries providing bidirectional mappings between module
- names and module instances.
- """
- name2mod = {}
- mod2name = {}
- for mapping in name2mod_mappings:
- mod2name.update({v: k for k, v in mapping.items()})
- name2mod.update(mapping)
- return name2mod, mod2name
|