123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, Union
- import torch
- from torch import nn
class GlobalAvailMixin:
    """Mixin that registers instances in a shared, class-level lookup table.

    The table is a two-level mapping (group name -> key -> instance) stored
    in the class attribute ``_instances``. Classes that inherit from this
    mixin without redefining ``_instances`` all read and write the same
    table.
    """

    # Two-level registry; the 'default' group is pre-created.
    _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {
        'default': {}
    }

    def global_available(self,
                         key: Union[str, nn.Module] = 'default',
                         group: str = 'default') -> None:
        """Register this instance so it can later be retrieved via ``find``.

        Args:
            key (Union[str, nn.Module], optional): Key under which the
                instance is stored. Defaults to 'default'.
            group (str, optional): Group the key belongs to.
                Defaults to 'default'.
        """
        self._save_instance(self, key, group)

    @classmethod
    def _save_instance(cls,
                       instance: 'GlobalAvailMixin',
                       key: Union[str, nn.Module] = 'default',
                       group: str = 'default') -> None:
        """Store ``instance`` in the registry under ``group`` / ``key``.

        A group that does not exist yet is created on demand. An existing
        entry with the same group and key is overwritten.

        Args:
            instance (GlobalAvailMixin): Instance to store.
            key (Union[str, nn.Module], optional): Key under which the
                instance is stored. Defaults to 'default'.
            group (str, optional): Group the key belongs to.
                Defaults to 'default'.
        """
        registry = cls._instances
        group_known = group in registry
        if not group_known:
            # Only newly created groups are type-checked.
            assert isinstance(group, str)
            registry[group] = {}
        registry[group][key] = instance

    @classmethod
    def find(cls,
             key: Union[str, nn.Module] = 'default',
             group: str = 'default') -> Union[None, 'GlobalAvailMixin']:
        """Look up a single registered instance.

        Args:
            key (Union[str, nn.Module], optional): Key of the instance.
                Defaults to 'default'.
            group (str, optional): Group of the instance.
                Defaults to 'default'.

        Returns:
            Union[None, GlobalAvailMixin]: The registered instance, or None
                when either the group or the key is unknown.
        """
        fallback = {}
        members = cls._instances.get(group, fallback)
        return members.get(key)

    @classmethod
    def find_group(
            cls,
            group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']:
        """Return every instance registered in ``group``.

        Args:
            group (str): Group of the instances.

        Returns:
            Dict[Union[str, nn.Module], GlobalAvailMixin]: The registry's own
                mapping for the group (not a copy), or a new empty dict when
                the group does not exist.
        """
        try:
            return cls._instances[group]
        except KeyError:
            return {}

    @classmethod
    def instances(
            cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]:
        """Return the whole registry (the live mapping, not a copy)."""
        return cls._instances
class KVCacheObserver(GlobalAvailMixin):
    """Track running max, min and absolute max of observed 4-D tensors.

    Statistics are kept per (head, head_dim) position: each call to
    ``observe`` reduces over the batch and sequence axes and folds the
    result into the stored tensors.
    """

    def __init__(self, num_head: int, head_dim: int) -> None:
        """Constructor for KVCacheObserver.

        Args:
            num_head : Number of heads
            head_dim : Dimension of each head
        """
        self.num_head = num_head
        self.head_dim = head_dim
        stat_shape = (num_head, head_dim)
        # Start from the identity of each reduction so the first observation
        # always replaces the initial value.
        self.max_val = torch.full(stat_shape, -torch.inf, dtype=torch.float16)
        self.min_val = torch.full(stat_shape, torch.inf, dtype=torch.float16)
        self.absmax_val = torch.full(stat_shape, 0, dtype=torch.float16)

    @torch.no_grad()
    def observe(self, x: torch.Tensor) -> None:
        """Update the running max, min and absolute max from ``x``.

        Args:
            x : Input tensor laid out as (bs, num_head, seqlen, head_dim) or
                (bs, seqlen, num_head, head_dim).

        Raises:
            AssertionError: If ``x`` is not 4-dimensional.
            RuntimeError: If neither supported layout matches ``x``.
        """
        assert len(x.shape) == 4, (
            f'expected a 4-D tensor, got shape {tuple(x.shape)}')

        # NOTE: the (bs, heads, seqlen, dims) layout is tested first, so a
        # (bs, seqlen, heads, dims) input whose seqlen equals num_head is
        # also treated as heads-first and transposed.
        if x.size(1) == self.num_head and x.size(3) == self.head_dim:
            # layout: (bs, heads, seqlen, dims)
            x = x.transpose(1, 2)
        elif x.size(2) != self.num_head or x.size(3) != self.head_dim:
            raise RuntimeError(
                'Unexpected dimensions for x, expected (bs, num_head, seqlen, head_dim) or (bs, seqlen, num_head, head_dim)'
            )

        # Flatten once and reuse it: on a transposed (non-contiguous) tensor
        # flatten may have to copy, so repeating it per statistic is wasteful.
        rows = x.flatten(0, 1)  # (bs * seqlen, num_head, head_dim)
        cur_max = rows.max(0)[0].cpu()
        cur_min = rows.min(0)[0].cpu()
        cur_absmax = rows.abs().max(0)[0].cpu()

        self.max_val = torch.maximum(self.max_val, cur_max)
        self.min_val = torch.minimum(self.min_val, cur_min)
        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)
class ActivationObserver(GlobalAvailMixin):
    """Track running max, min, mean, absolute max and absolute mean of
    observed 3-D tensors, per position along the last axis.

    Also keeps track of the number of batches observed. The mean and
    absolute mean are averages of the per-call means, so every ``observe``
    call carries equal weight regardless of how many rows it contained.
    """

    def __init__(self, dim: int) -> None:
        """Constructor for ActivationObserver.

        Args:
            dim : Dimension of the tensor
        """
        self.dim = dim
        stat_shape = (dim, )
        # max/min start from the identity of their reduction; the rest
        # start from zero.
        self.max_val = torch.full(stat_shape, -torch.inf, dtype=torch.float16)
        self.min_val = torch.full(stat_shape, torch.inf, dtype=torch.float16)
        self.absmax_val = torch.full(stat_shape, 0, dtype=torch.float16)
        self.absmean_val = torch.full(stat_shape, 0, dtype=torch.float16)
        self.mean_val = torch.full(stat_shape, 0, dtype=torch.float16)
        self.num_batches_tracked = 0

    @torch.no_grad()
    def observe(self, x: torch.Tensor) -> None:
        """Update the running statistics and the batch counter from ``x``.

        Args:
            x : Input tensor with three dimensions whose last dimension has
                size ``dim``.

        Raises:
            AssertionError: If ``x`` is not 3-dimensional or its last
                dimension does not equal ``dim``.
        """
        assert len(x.shape) == 3, (
            f'expected a 3-D tensor, got shape {tuple(x.shape)}')
        assert x.size(2) == self.dim, (
            f'expected last dimension {self.dim}, got {x.size(2)}')

        # Merge the two leading axes so every reduction runs over all rows.
        cur_val = x.flatten(0, 1)
        cur_max = cur_val.max(0)[0].cpu()
        cur_min = cur_val.min(0)[0].cpu()
        cur_mean = cur_val.mean(0).cpu()

        cur_abs = cur_val.abs()
        cur_absmax = cur_abs.max(0)[0].cpu()
        cur_absmean = cur_abs.mean(0).cpu()

        self.max_val = torch.maximum(self.max_val, cur_max)
        self.min_val = torch.minimum(self.min_val, cur_min)
        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)

        # Incremental running average: mean += (cur - mean) / (n + 1).
        # Rebuilding the accumulated sum as ``mean * n`` overflows float16
        # (max ~65504) after enough batches and turns the statistic into
        # inf; the incremental form keeps intermediates near the data range.
        seen = self.num_batches_tracked
        self.mean_val = (self.mean_val +
                         (cur_mean - self.mean_val) / (seen + 1))
        self.absmean_val = (self.absmean_val +
                            (cur_absmean - self.absmean_val) / (seen + 1))

        # Increment the count of batches tracked
        self.num_batches_tracked = seen + 1
|