utils.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import nn


def split_decoder_layer_inputs(
    *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any]
) -> Tuple[List[List[Any]], List[Dict[str, Any]]]:
    """This function splits batched decoder layer inputs into individual
    elements.

    Args:
        *args (Union[torch.Tensor, Any]): Positional arguments which could
            be a mix of tensors and other types.
        **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could
            be a mix of tensors and other types.

    Returns:
        Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two
        lists, one for positional arguments, one for keyword arguments.
        Each list contains individual elements from the batch.
    """
    if not isinstance(args[0], torch.Tensor):
        raise ValueError('The first argument must be a Tensor')

    bs = args[0].size(0)

    batch_args = []
    batch_kwargs = []
    for i in range(bs):
        new_args = []
        # Iterate over each argument. If it's a torch.Tensor and its first
        # dimension equals the batch size, then get the value corresponding
        # to the current index, else directly add the whole value.
        for val in args:
            if isinstance(val, torch.Tensor) and val.size(0) == bs:
                new_args.append(val[i:i + 1])
            else:
                new_args.append(val)

        new_kwargs = {}
        # Execute the same operation for the keyword arguments.
        for name, val in kwargs.items():
            if isinstance(val, torch.Tensor) and val.size(0) == bs:
                new_kwargs[name] = val[i:i + 1]
            else:
                new_kwargs[name] = val

        batch_args.append(new_args)
        batch_kwargs.append(new_kwargs)

    return batch_args, batch_kwargs
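

# A minimal usage sketch, not part of the original file: it assumes a toy
# decoder-layer call whose positional input is a batched hidden-state tensor
# and whose kwargs mix a batched attention mask with a plain Python flag.
def _demo_split_decoder_layer_inputs():
    hidden_states = torch.randn(4, 16, 32)  # (batch, seq_len, hidden_dim)
    attention_mask = torch.ones(4, 16)  # (batch, seq_len)
    batch_args, batch_kwargs = split_decoder_layer_inputs(
        hidden_states, attention_mask=attention_mask, use_cache=True)
    assert len(batch_args) == 4
    # Each per-sample slice keeps a batch dimension of size 1.
    assert batch_args[0][0].shape == (1, 16, 32)
    # Non-tensor kwargs are copied to every element unchanged.
    assert batch_kwargs[0]['use_cache'] is True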


def concat_decoder_layer_outputs(
        batch_outputs: List[Tuple[Any]]) -> Tuple[Any]:
    """This function concatenates individual decoder layer outputs into a
    batched output.

    Args:
        batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple
            represents the output from an individual element in the batch.

    Returns:
        Tuple[Any]: A tuple representing the batched output.
    """
    num_returns = len(batch_outputs[0])

    def is_past_key_value(data: Any) -> bool:
        """Check whether data is a past key-value pair.

        Args:
            data (Any): The data to check.

        Returns:
            bool: True if data is a past key-value pair, False otherwise.
        """
        flag = isinstance(data, tuple)
        flag = flag and len(data) == 2
        flag = flag and isinstance(data[0], torch.Tensor)
        flag = flag and isinstance(data[1], torch.Tensor)
        return flag

    new_outputs = []
    # Iterate over all types of return values.
    for i in range(num_returns):
        # Check if the current element is a past key-value pair.
        flag = is_past_key_value(batch_outputs[0][i])
        if flag:
            # Concatenate the keys and values separately.
            key = torch.cat([out[i][0] for out in batch_outputs])
            value = torch.cat([out[i][1] for out in batch_outputs])
            out_i = (key, value)
        else:
            # If it's not a past key-value pair, concatenate directly.
            out_i = torch.cat([out[i] for out in batch_outputs])
        new_outputs.append(out_i)

    return tuple(new_outputs)
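

# A minimal usage sketch, not part of the original file: the per-sample
# outputs below are made-up tensors shaped like the (hidden_states,
# past_key_value) pair a decoder layer returns when run with batch size 1.
def _demo_concat_decoder_layer_outputs():
    per_sample_outputs = []
    for _ in range(4):
        hidden = torch.randn(1, 16, 32)
        past_kv = (torch.randn(1, 2, 16, 8), torch.randn(1, 2, 16, 8))
        per_sample_outputs.append((hidden, past_kv))
    hidden, past_kv = concat_decoder_layer_outputs(per_sample_outputs)
    assert hidden.shape == (4, 16, 32)
    # Keys and values of the cache are concatenated separately.
    assert past_kv[0].shape == (4, 2, 16, 8)
    assert past_kv[1].shape == (4, 2, 16, 8)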


def collect_target_modules(
        model: nn.Module,
        target: Union[str, type],
        skip_names: List[str] = None,
        prefix: str = '') -> Dict[str, nn.Module]:
    """Collects the specific target modules from the model.

    Args:
        model : The PyTorch module from which to collect the target modules.
        target : The specific target to be collected. It can be a class of a
            module or the name of a module.
        skip_names : List of names of modules to be skipped during collection.
        prefix : A string to be added as a prefix to the module names.

    Returns:
        A dictionary mapping from module names to module instances.
    """
    # if isinstance(target, LazyAttr):
    #     target = target.build()
    if skip_names is None:
        skip_names = []
    if not isinstance(target, (type, str)):
        raise TypeError('Target must be a string (name of the module) '
                        'or a type (class of the module)')

    def _is_target(n, m):
        if isinstance(target, str):
            return target == type(m).__name__ and n not in skip_names
        return isinstance(m, target) and n not in skip_names

    name2mod = {}
    for name, mod in model.named_modules():
        m_name = f'{prefix}.{name}' if prefix else name
        if _is_target(name, mod):
            name2mod[m_name] = mod
    return name2mod
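

# A minimal usage sketch, not part of the original file: it collects the
# nn.Linear layers of a toy nn.Sequential, once by class and once by class
# name, skipping one submodule and prefixing the returned keys.
def _demo_collect_target_modules():
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    # Match by module class.
    linears = collect_target_modules(model, nn.Linear)
    assert set(linears) == {'0', '2'}
    # Match by class name, skip the first linear and prefix the keys.
    named = collect_target_modules(model, 'Linear',
                                   skip_names=['0'], prefix='model')
    assert set(named) == {'model.2'}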


def bimap_name_mod(
    name2mod_mappings: List[Dict[str, nn.Module]]
) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]:
    """Generates bidirectional maps from module names to module instances and
    vice versa.

    Args:
        name2mod_mappings : List of dictionaries each mapping from module
            names to module instances.

    Returns:
        Two dictionaries providing bidirectional mappings between module
        names and module instances.
    """
    name2mod = {}
    mod2name = {}
    for mapping in name2mod_mappings:
        mod2name.update({v: k for k, v in mapping.items()})
        name2mod.update(mapping)

    return name2mod, mod2name
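

# A minimal usage sketch, not part of the original file: it merges two
# made-up name-to-module mappings into forward and reverse lookup tables.
def _demo_bimap_name_mod():
    fc1, fc2 = nn.Linear(8, 8), nn.Linear(8, 4)
    name2mod, mod2name = bimap_name_mod([{'fc1': fc1}, {'fc2': fc2}])
    assert name2mod['fc1'] is fc1
    # The reverse map resolves a module instance back to its name.
    assert mod2name[fc2] == 'fc2'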