123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from dataclasses import dataclass
- from typing import Any, Dict, List, Tuple
- import torch
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.utils import is_pin_memory_available
class PoolingMetadata:
    """Container for the information the Pooler layer needs.

    Instances simply hold the three values passed to the constructor; no
    validation or conversion is performed here.

    Attributes:
        seq_groups: List of (seq_ids, pooling_params) pairs.
        seq_data: Mapping from sequence ID to additional sequence data.
        prompt_lens: Length of each prompt.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], PoolingParams]],
        seq_data: Dict[int, Any],  # Per-sequence data keyed by sequence ID
        prompt_lens: List[int],
    ) -> None:
        """Store the given values on the instance unchanged."""
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens

    def __repr__(self) -> str:
        """Return ``PoolingMetadata(seq_groups=..., seq_data=..., prompt_lens=...)``."""
        rendered_fields = (
            f"seq_groups={self.seq_groups}",
            f"seq_data={self.seq_data}",
            f"prompt_lens={self.prompt_lens}",
        )
        return "PoolingMetadata(" + ", ".join(rendered_fields) + ")"
@dataclass
class PoolingTensors:
    """Tensor counterpart of the data carried by PoolingMetadata."""

    # Prompt lengths as a tensor (built with dtype torch.long by
    # from_pooling_metadata).
    prompt_lens: torch.Tensor

    @classmethod
    def from_pooling_metadata(
        cls,
        pooling_metadata: "PoolingMetadata",
        device: torch.device,
    ) -> "PoolingTensors":
        """Build a PoolingTensors from a PoolingMetadata instance.

        Args:
            pooling_metadata: Source metadata; only its ``prompt_lens``
                attribute is read.
            device: Device the resulting tensor is moved to.

        Returns:
            A PoolingTensors whose ``prompt_lens`` is a ``torch.long``
            tensor transferred to ``device``.
        """
        # Query pinning support first, then stage the lengths on the CPU
        # (pinned if the helper says so) before the non-blocking transfer.
        use_pinned = is_pin_memory_available()
        cpu_prompt_lens = torch.tensor(
            pooling_metadata.prompt_lens,
            device="cpu",
            dtype=torch.long,
            pin_memory=use_pinned,
        )
        device_prompt_lens = cpu_prompt_lens.to(device=device,
                                                non_blocking=True)
        return cls(prompt_lens=device_prompt_lens)
|