# neuron_model_runner.py

from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import (IntermediateTensors,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import (is_pin_memory_available,
                                    make_tensor_with_pad)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.neuron import get_neuron_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs)
from aphrodite.worker.model_runner_base import (ModelRunnerBase,
                                                ModelRunnerInputBase)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend


@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        assert attn_backend is None
        # Build the dataclass directly from the dict; calling this classmethod
        # on itself here would recurse indefinitely.
        return cls(**tensor_dict)


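# NOTE: The runner below targets models served through transformers-neuronx
# (see load_model). It does not consume the paged KV caches used by the GPU
# runners: execute_model accepts a kv_caches argument but never uses it, and
# every sequence is addressed through a single block id instead (enforced by
# the asserts in _prepare_prompt/_prepare_decode).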
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

    def load_model(self) -> None:
        if find_spec("transformers_neuronx") is not None:
            self.model = get_neuron_model(
                self.model_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)
        else:
            raise NotImplementedError(
                "Supports only Transformer-NeuronX based models.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_seq_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_seq_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

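    # Decode preparation mirrors _prepare_prompt but emits exactly one token
    # per running sequence: the last generated token, placed at position
    # seq_len - 1. context_lens is computed below but not returned; only the
    # tokens, positions, and block ids are passed on to the model.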
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        multi_modal_kwargs = None
        # NOTE: We assume the sequences in the group are either all prompts
        # or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = []
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed when chunked prefill is not supported.
            # Since the Neuron worker doesn't support chunked prefill, just
            # use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

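    # Execution is strictly single-step: one forward pass per call
    # (num_steps > 1 is rejected), followed by compute_logits and sample on
    # the loaded model.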
    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()
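

# Rough usage sketch (illustrative only): the worker that actually drives this
# runner is not shown here, and the config objects and seq_group_metadata_list
# below are assumed to come from the engine/scheduler rather than being
# constructed by hand.
#
#     runner = NeuronModelRunner(model_config, parallel_config,
#                                scheduler_config, device_config)
#     runner.load_model()
#     model_input = runner.prepare_model_input(seq_group_metadata_list)
#     outputs = runner.execute_model(model_input)  # -> [SamplerOutput]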