# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union

import torch
from torch import nn
  5. class GlobalAvailMixin:
  6. """Mixin class to make instances globally available."""
  7. _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {
  8. 'default': {}
  9. }
  10. def global_available(self,
  11. key: Union[str, nn.Module] = 'default',
  12. group: str = 'default') -> None:
  13. """Make the instance globally available.
  14. Args:
  15. key (Union[str, nn.Module], optional): Key to save the instance.
  16. Defaults to 'default'.
  17. group (str, optional): Group to save the instance.
  18. Defaults to 'default'.
  19. """
  20. self._save_instance(self, key, group)
  21. @classmethod
  22. def _save_instance(cls,
  23. instance: 'GlobalAvailMixin',
  24. key: Union[str, nn.Module] = 'default',
  25. group: str = 'default') -> None:
  26. """Save the instance.
  27. Args:
  28. instance (GlobalAvailMixin): Instance to save.
  29. key (Union[str, nn.Module], optional): Key to save the instance.
  30. Defaults to 'default'.
  31. group (str, optional): Group to save the instance.
  32. Defaults to 'default'.
  33. """
  34. if group not in cls._instances:
  35. assert isinstance(group, str)
  36. cls._instances[group] = {}
  37. cls._instances[group][key] = instance
  38. @classmethod
  39. def find(cls,
  40. key: Union[str, nn.Module] = 'default',
  41. group: str = 'default') -> Union[None, 'GlobalAvailMixin']:
  42. """Find an instance by its key and group.
  43. Args:
  44. key (Union[str, nn.Module], optional): Key of the instance.
  45. Defaults to 'default'.
  46. group (str, optional): Group of the instance.
  47. Defaults to 'default'.
  48. Returns:
  49. Union[None, GlobalAvailMixin]: The found instance, or None if
  50. it does not exist.
  51. """
  52. return cls._instances.get(group, {}).get(key)
  53. @classmethod
  54. def find_group(
  55. cls,
  56. group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']:
  57. """Find all instances in a group.
  58. Args:
  59. group (str): Group of the instances.
  60. Returns:
  61. Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in
  62. the group.
  63. """
  64. return cls._instances.get(group, {})
  65. @classmethod
  66. def instances(
  67. cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]:
  68. """Get all instances."""
  69. return cls._instances
  70. class KVCacheObserver(GlobalAvailMixin):
  71. """A class to observe and record the max, min, and absolute max value of
  72. given tensor."""
  73. def __init__(self, num_head: int, head_dim: int) -> None:
  74. """Constructor for KVCacheObserver.
  75. Args:
  76. num_head : Number of heads
  77. head_dim : Dimension of each head
  78. """
  79. self.num_head = num_head
  80. self.head_dim = head_dim
  81. self.max_val = torch.full((num_head, head_dim),
  82. -torch.inf,
  83. dtype=torch.float16)
  84. self.min_val = torch.full((num_head, head_dim),
  85. torch.inf,
  86. dtype=torch.float16)
  87. self.absmax_val = torch.full((num_head, head_dim),
  88. 0,
  89. dtype=torch.float16)
  90. @torch.no_grad()
  91. def observe(self, x: torch.Tensor) -> None:
  92. """Function to observe the input tensor and update the max, min, and
  93. absolute max values.
  94. Args:
  95. x : Input tensor
  96. """
  97. assert len(x.shape) == 4
  98. if x.size(1) == self.num_head and x.size(3) == self.head_dim:
  99. # layout: (bs, heads, seqlen, dims)
  100. x = x.transpose(1, 2)
  101. elif x.size(2) != self.num_head or x.size(3) != self.head_dim:
  102. raise RuntimeError(
  103. 'Unexpected dimensions for x, expected (bs, num_head, '
  104. 'seqlen, head_dim) or (bs, seqlen, num_head, head_dim)')
  105. cur_max = x.flatten(0, 1).max(0)[0].cpu()
  106. cur_min = x.flatten(0, 1).min(0)[0].cpu()
  107. cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu()
  108. self.max_val = torch.maximum(self.max_val, cur_max)
  109. self.min_val = torch.minimum(self.min_val, cur_min)
  110. self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)
  111. class ActivationObserver(GlobalAvailMixin):
  112. """A class to observe and record the max, min, mean, absolute max, and
  113. absolute mean value of a given tensor.
  114. Also keeps track of the number of batches observed.
  115. """
  116. def __init__(self, dim: int) -> None:
  117. """Constructor for ActivationObserver.
  118. Args:
  119. dim : Dimension of the tensor
  120. """
  121. self.dim = dim
  122. self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16)
  123. self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16)
  124. self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16)
  125. self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16)
  126. self.mean_val = torch.full((dim, ), 0, dtype=torch.float16)
  127. self.num_batches_tracked = 0
  128. @torch.no_grad()
  129. def observe(self, x: torch.Tensor) -> None:
  130. """Function to observe the input tensor and update the max, min, mean,
  131. absolute max, absolute mean values and number of batches tracked.
  132. Args:
  133. x : Input tensor
  134. """
  135. assert len(x.shape) == 3
  136. assert x.size(2) == self.dim
  137. cur_val = x.flatten(0, 1)
  138. cur_max = cur_val.max(0)[0].cpu()
  139. cur_min = cur_val.min(0)[0].cpu()
  140. cur_mean = cur_val.mean(0).cpu()
  141. cur_abs = cur_val.abs()
  142. cur_absmax = cur_abs.max(0)[0].cpu()
  143. cur_absmean = cur_abs.mean(0).cpu()
  144. self.max_val = torch.maximum(self.max_val, cur_max)
  145. self.min_val = torch.minimum(self.min_val, cur_min)
  146. self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)
  147. # Update mean and absmean value with accumulated sum divided
  148. # by total number of batches
  149. self.mean_val = (
  150. (self.mean_val * self.num_batches_tracked + cur_mean) /
  151. (self.num_batches_tracked + 1))
  152. self.absmean_val = (
  153. (self.absmean_val * self.num_batches_tracked + cur_absmean) /
  154. (self.num_batches_tracked + 1))
  155. # Increment the count of batches tracked
  156. self.num_batches_tracked += 1