# test_model_input.py
import dataclasses
from typing import List, Tuple, Type

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.task_handler.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata)
  12. class MockAttentionBackend(AttentionBackend):
  13. @staticmethod
  14. def get_name() -> str:
  15. raise NotImplementedError
  16. @staticmethod
  17. def get_impl_cls():
  18. raise NotImplementedError
  19. @staticmethod
  20. def get_metadata_cls() -> Type["AttentionMetadata"]:
  21. return AttentionMetadata
  22. @staticmethod
  23. def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
  24. raise AttentionMetadataBuilder
  25. @staticmethod
  26. def get_kv_cache_shape(
  27. num_blocks: int,
  28. block_size: int,
  29. num_kv_heads: int,
  30. head_size: int,
  31. ) -> Tuple[int, ...]:
  32. raise NotImplementedError
  33. @staticmethod
  34. def swap_blocks(
  35. src_kv_cache: torch.Tensor,
  36. dst_kv_cache: torch.Tensor,
  37. src_to_dst: torch.Tensor,
  38. ) -> None:
  39. pass
  40. @staticmethod
  41. def copy_blocks(
  42. kv_caches: List[torch.Tensor],
  43. src_to_dists: torch.Tensor,
  44. ) -> None:
  45. pass
  46. def test_model_runner_input():
  47. sampling_metadata = SamplingMetadata(
  48. ["seq_group"],
  49. "selected_token_indices",
  50. "categorized_sample_indices",
  51. "num_prompts",
  52. )
  53. attn_metadata = AttentionMetadata(
  54. num_prefills=1,
  55. num_prefill_tokens=2,
  56. num_decode_tokens=3,
  57. slot_mapping=torch.zeros(1),
  58. )
  59. model_input = ModelInputForGPUWithSamplingMetadata(
  60. input_tokens=torch.ones(10),
  61. input_positions=torch.ones(10),
  62. sampling_metadata=sampling_metadata,
  63. attn_metadata=attn_metadata)
  64. assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
  65. # Test round trip serialization.
  66. tensor_dict = model_input.as_broadcastable_tensor_dict()
  67. attn_backend = MockAttentionBackend()
  68. received_model_input = (
  69. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  70. tensor_dict, attn_backend=attn_backend))
  71. # Check that received copy has correct values.
  72. assert isinstance(received_model_input,
  73. ModelInputForGPUWithSamplingMetadata)
  74. assert received_model_input.input_tokens is not None
  75. assert (
  76. received_model_input.input_tokens == model_input.input_tokens).all()
  77. assert received_model_input.input_positions is not None
  78. assert (received_model_input.input_positions == model_input.input_positions
  79. ).all()
  80. assert received_model_input.multi_modal_kwargs is None
  81. assert (received_model_input.multi_modal_kwargs ==
  82. model_input.multi_modal_kwargs)
  83. assert received_model_input.lora_requests is None
  84. assert received_model_input.lora_requests == model_input.lora_requests
  85. assert received_model_input.lora_mapping is None
  86. assert received_model_input.lora_mapping == model_input.lora_mapping
  87. for field in dataclasses.fields(AttentionMetadata):
  88. assert getattr(received_model_input.attn_metadata, field.name,
  89. None) == getattr(attn_metadata, field.name, None)
  90. # For sampling metadata, only selected_token_indices is copied.
  91. assert (received_model_input.sampling_metadata.selected_token_indices ==
  92. sampling_metadata.selected_token_indices)
  93. assert received_model_input.sampling_metadata.seq_groups is None
  94. def test_embedding_model_runner_input():
  95. pooling_metadata = PoolingMetadata(
  96. seq_groups=[[0]],
  97. seq_data={},
  98. prompt_lens=[1],
  99. )
  100. attn_metadata = AttentionMetadata(
  101. num_prefills=1,
  102. num_prefill_tokens=2,
  103. num_decode_tokens=3,
  104. slot_mapping=torch.zeros(1),
  105. )
  106. model_input = ModelInputForGPUWithPoolingMetadata(
  107. input_tokens=torch.ones(10),
  108. input_positions=torch.ones(10),
  109. pooling_metadata=pooling_metadata,
  110. attn_metadata=attn_metadata)
  111. assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
  112. # Test round trip serialization.
  113. tensor_dict = model_input.as_broadcastable_tensor_dict()
  114. attn_backend = MockAttentionBackend()
  115. received_model_input = (
  116. ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
  117. tensor_dict, attn_backend=attn_backend))
  118. # Check that received copy has correct values.
  119. assert isinstance(received_model_input,
  120. ModelInputForGPUWithPoolingMetadata)
  121. assert received_model_input.input_tokens is not None
  122. assert (
  123. received_model_input.input_tokens == model_input.input_tokens).all()
  124. assert received_model_input.input_positions is not None
  125. assert (received_model_input.input_positions == model_input.input_positions
  126. ).all()
  127. assert received_model_input.multi_modal_kwargs is None
  128. assert (received_model_input.multi_modal_kwargs ==
  129. model_input.multi_modal_kwargs)
  130. assert received_model_input.lora_requests is None
  131. assert received_model_input.lora_requests == model_input.lora_requests
  132. assert received_model_input.lora_mapping is None
  133. assert received_model_input.lora_mapping == model_input.lora_mapping
  134. for field in dataclasses.fields(AttentionMetadata):
  135. assert getattr(received_model_input.attn_metadata, field.name,
  136. None) == getattr(attn_metadata, field.name, None)
  137. # Pooling metadata is not broadcast.
  138. assert received_model_input.pooling_metadata is None