pooler.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from enum import IntEnum
  2. import torch
  3. import torch.nn as nn
  4. from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
  5. PoolerOutput)
  6. from aphrodite.modeling.pooling_metadata import PoolingMetadata, PoolingTensors
  7. class PoolingType(IntEnum):
  8. """Enumeration for different types of pooling methods."""
  9. LAST = 0
  10. class Pooler(nn.Module):
  11. """A layer that pools specific information from hidden states.
  12. This layer does the following:
  13. 1. Extracts specific tokens or aggregates data based on pooling method.
  14. 2. Normalizes output if specified.
  15. 3. Returns structured results as `PoolerOutput`.
  16. Attributes:
  17. pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
  18. normalize: Whether to normalize the pooled data.
  19. """
  20. def __init__(self, pooling_type: PoolingType, normalize: bool):
  21. super().__init__()
  22. self.pooling_type = pooling_type
  23. self.normalize = normalize
  24. def forward(
  25. self,
  26. hidden_states: torch.Tensor,
  27. pooling_metadata: PoolingMetadata,
  28. ) -> PoolerOutput:
  29. """Pools specific information from hidden states based on metadata."""
  30. prompt_lens = PoolingTensors.from_pooling_metadata(
  31. pooling_metadata, hidden_states.device).prompt_lens
  32. if self.pooling_type == PoolingType.LAST:
  33. last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
  34. pooled_data = hidden_states[last_token_flat_indices]
  35. else:
  36. raise ValueError(f"Invalid pooling type: {self.pooling_type}")
  37. if self.normalize:
  38. pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
  39. pooled_outputs = [
  40. EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
  41. ]
  42. return PoolerOutput(outputs=pooled_outputs)