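"""Tests for broadcastable model input serialization.

Each test builds a GPU model input, serializes it with
as_broadcastable_tensor_dict(), deserializes it with
from_broadcasted_tensor_dict(), and checks that the round trip preserves
exactly the fields that are meant to survive a broadcast.
"""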

import dataclasses
from typing import List, Tuple, Type

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from aphrodite.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from aphrodite.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return AttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass
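

# Of the hooks above, only get_metadata_cls() needs a real implementation:
# from_broadcasted_tensor_dict() asks the attention backend for its metadata
# class when rebuilding AttentionMetadata on the receiving side. The other
# hooks are stubbed out and are not exercised by these tests.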


def test_model_runner_input():
    """Round trip a ModelInputForGPUWithSamplingMetadata through the
    broadcastable tensor dict."""
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that the received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions ==
            model_input.input_positions).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None


def test_embedding_model_runner_input():
    """Round trip a ModelInputForGPUWithPoolingMetadata; the pooling
    metadata itself is not broadcast."""
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that the received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions ==
            model_input.input_positions).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
    """Round trip a StatefulModelInput, covering both the frozen inner
    model input and the mutable multi-step fields."""
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))
    received_frozen_input = received_model_input.frozen_model_input

    # Check that the received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (received_frozen_input.lora_mapping ==
            frozen_model_input.lora_mapping)
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
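

if __name__ == "__main__":
    # These tests are normally collected and run by pytest; running the
    # module directly is just a convenience that executes them in order.
    test_model_runner_input()
    test_embedding_model_runner_input()
    test_multi_step_model_runner_input()
    print("all model input round-trip tests passed")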