# test_model_input.py

import dataclasses
from typing import List, Tuple, Type

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.task_handler.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata)
from aphrodite.task_handler.multi_step_model_runner import StatefulModelInput
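

# Minimal AttentionBackend stub for the round-trip tests below: rebuilding a
# model input from a broadcast tensor dict should only need the metadata and
# builder class lookups, so the rest of the interface is left unimplemented.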
class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return AttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass
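

# Round-trip a sampling-model input through its broadcastable tensor-dict
# form (the representation the driver worker broadcasts to the other workers)
# and check which fields survive the trip.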
def test_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)
    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None
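

# Same round trip for the embedding model runner's input; pooling metadata is
# expected to be dropped rather than broadcast.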
def test_embedding_model_runner_input():
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)
    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None
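

# Same round trip for the multi-step runner's StatefulModelInput, which wraps
# a frozen per-step model input plus mutable multi-step bookkeeping fields.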
def test_multi_step_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)
    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )
    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))
    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen (mutable multi-step) fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
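

# To run these tests directly (assuming pytest is available):
#   pytest test_model_input.py -v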