# neuron_model_runner.py

from typing import Dict, List, Optional, Tuple

import torch
from loguru import logger

from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.neuron_loader import get_neuron_model
from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)


class NeuronModelRunner:
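    """Model runner for the AWS Neuron backend.

    Converts scheduler sequence-group metadata into padded input tensors
    and sampling metadata, then drives the loaded Neuron model through a
    forward pass, logits computation, and sampling.
    """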

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.model = None
        self.pin_memory = is_pin_memory_available()

    def load_model(self) -> None:
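        """Load the model through the Neuron model loader."""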
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
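        """Build padded input tensors for a prompt (prefill) batch.

        Each sequence group must hold exactly one sequence, and each
        block table exactly one block; both invariants are asserted
        below. Returns the padded token and position tensors, the block
        ids, and the raw prompt lengths.
        """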
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])
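
        # Pad every prompt to the longest prompt in the batch so the
        # per-sequence token and position lists stack into rectangular
        # tensors.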
        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
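        """Build input tensors for a decode batch.

        Each running sequence contributes exactly one input token (its
        most recently generated token) and one position (its current
        length minus one).
        """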
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])
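
        # A decode step consumes exactly one token per sequence, so the
        # padded tensors have max_len=1.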
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
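        """Construct SamplingMetadata for the batch.

        Tracks which rows of the model output are needed
        (selected_token_indices) and groups the positions to sample by
        SamplingType (categorized_sample_indices), alongside per-group
        sampling params and any seeded generators.
        """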
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert prompt_lens is not None
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need to be
                    # sampled; skip them.
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append([
                        categorized_sample_indices_start_idx,
                        categorized_sampled_token_indices_start_idx
                    ])
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += prompt_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)
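
        # Move the index lists host-to-device; pinned memory lets the
        # copy be asynchronous when the platform supports it.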
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
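        """Build model input tensors and sampling metadata for the batch,
        dispatching to the prompt or decode path as appropriate.
        """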
        # NOTE: We assume the batch is homogeneous: either all prompts
        # or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            prompt_lens = []
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens)

        return (input_tokens, input_positions, input_block_ids,
                sampling_metadata)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Optional[SamplerOutput]:
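        """Run a forward pass, compute logits, and sample the next tokens."""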
        (input_tokens, input_positions, input_block_ids, sampling_metadata
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_block_ids=input_block_ids,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()