# test_model_input.py

import dataclasses
from typing import List, Tuple, Type

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.task_handler.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata)
from aphrodite.task_handler.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
    """Stub backend: only get_metadata_cls() is exercised by these tests."""

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        # `return`, not `raise`: AttentionMetadataBuilder is not an exception.
        return AttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass
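

# A minimal sketch of the round trip the tests below exercise (a hypothetical
# helper, not called by the tests). Rebuilding a model input from a broadcast
# dict only needs the backend's get_metadata_cls(), which is why every other
# MockAttentionBackend method can safely raise or no-op.
def _round_trip_sketch() -> ModelInputForGPUWithSamplingMetadata:
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(4),
        input_positions=torch.ones(4),
        attn_metadata=AttentionMetadata(
            num_prefills=1,
            num_prefill_tokens=2,
            num_decode_tokens=2,
            slot_mapping=torch.zeros(1),
        ))
    # Flatten into a dict of tensors and scalars suitable for broadcast...
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    # ...then rebuild it as it would appear on a receiving worker.
    return ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=MockAttentionBackend())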


def test_model_runner_input():
    # Placeholder (non-tensor) values are fine here: the round trip below
    # only preserves selected_token_indices from the sampling metadata.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)
    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions ==
            model_input.input_positions).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None
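

# For reference, the broadcastable dict is flat: nested metadata objects are
# unpacked into top-level entries rather than serialized whole. Roughly (a
# sketch; the exact key set depends on the aphrodite version):
#
#     {
#         "input_tokens": <tensor>,
#         "input_positions": <tensor>,
#         "selected_token_indices": <...>,  # contributed by SamplingMetadata
#         "num_prefills": 1,                # contributed by AttentionMetadata
#         "num_prefill_tokens": 2,
#         ...
#     }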


def test_embedding_model_runner_input():
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)
    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions ==
            model_input.input_positions).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None
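

# Unlike sampling metadata, which contributes selected_token_indices to the
# payload, no PoolingMetadata field is flattened into the broadcast dict, so
# the receiving side always sees None. A quick way to confirm this directly
# (assuming dict keys mirror the dataclass field names):
#
#     assert "seq_groups" not in tensor_dict
#     assert "prompt_lens" not in tensor_dict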


def test_multi_step_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)
    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )
    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend)
    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check non-frozen fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
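

# To run just these round-trip tests (assuming a pytest-based checkout;
# adjust the path to wherever this file lives in the tree):
#
#     pytest tests/worker/test_model_input.py -v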