pooling_metadata.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from dataclasses import dataclass
  2. from typing import Any, Dict, List, Tuple
  3. import torch
  4. from aphrodite.common.pooling_params import PoolingParams
  5. from aphrodite.common.utils import is_pin_memory_available
  6. class PoolingMetadata:
  7. """Metadata for pooling operations in the Pooler layer.
  8. This class holds the necessary information for pooling operations,
  9. providing context for how to perform pooling and other related operations.
  10. Attributes:
  11. seq_groups: List of (seq_ids, pooling_params).
  12. seq_data: A mapping of sequence ID to additional sequence data.
  13. prompt_lens: List of the lengths of each prompt.
  14. """
  15. def __init__(
  16. self,
  17. seq_groups: List[Tuple[List[int], PoolingParams]],
  18. seq_data: Dict[int, Any], # Specific data related to sequences
  19. prompt_lens: List[int],
  20. ) -> None:
  21. self.seq_groups = seq_groups
  22. self.seq_data = seq_data
  23. self.prompt_lens = prompt_lens
  24. def __repr__(self) -> str:
  25. return ("PoolingMetadata("
  26. f"seq_groups={self.seq_groups}, "
  27. f"seq_data={self.seq_data}, "
  28. f"prompt_lens={self.prompt_lens})")
  29. @dataclass
  30. class PoolingTensors:
  31. """Tensors for pooling."""
  32. prompt_lens: torch.Tensor
  33. @classmethod
  34. def from_pooling_metadata(
  35. cls,
  36. pooling_metadata: "PoolingMetadata",
  37. device: torch.device,
  38. ) -> "PoolingTensors":
  39. """
  40. Create PoolingTensors from PoolingMetadata.
  41. Args:
  42. pooling_metadata: PoolingMetadata instance to convert.
  43. device: Device to store the tensors.
  44. """
  45. # Convert prompt lengths to tensor
  46. pin_memory = is_pin_memory_available()
  47. prompt_lens_t = torch.tensor(
  48. pooling_metadata.prompt_lens,
  49. device="cpu",
  50. dtype=torch.long,
  51. pin_memory=pin_memory,
  52. )
  53. return cls(prompt_lens=prompt_lens_t.to(device=device,
  54. non_blocking=True), )