# neuron_model_runner.py

from typing import Dict, List, Optional, Tuple

import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import (DeviceConfig, ModelConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader.neuron import get_neuron_model


class NeuronModelRunner:
    """Prepares inputs for, and runs, a model on AWS Neuron devices."""

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization: set by load_model().
        self.model: nn.Module

    def load_model(self) -> None:
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)
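
    # NOTE: `get_neuron_model` is expected to return a module whose forward
    # pass takes `input_ids`, `positions`, and `input_block_ids`, and which
    # exposes `compute_logits()` and `sample()` (see `execute_model` below).
    # The loading and compilation details live in
    # `aphrodite.modeling.model_loader.neuron`.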
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            # Prompt sequence groups contain exactly one sequence.
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # On Neuron, each sequence is mapped to a single block.
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0
        # Right-pad every prompt to the longest prompt in the batch.
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids, prompt_lens
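
    # A worked example of the padding above, for two hypothetical prompts
    # with token ids [5, 6] and [7, 8, 9, 10]:
    #
    #   prompt_lens     = [2, 4]            -> max_prompt_len = 4
    #   input_tokens    = [[5, 6, 0, 0],
    #                      [7, 8, 9, 10]]   # right-padded with pad=0
    #   input_positions = [[0, 1, 0, 0],
    #                      [0, 1, 2, 3]]
    #   input_block_ids = one block id per sequence, shape [2]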
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids
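
    # In the decode path every sequence contributes exactly one new token, so
    # the tensors are fixed-width. For three hypothetical running sequences
    # whose last tokens are 11, 12, 13 at lengths 4, 7, 5:
    #
    #   input_tokens    = [[11], [12], [13]]   # shape [num_seqs, 1]
    #   input_positions = [[3], [6], [4]]      # seq_len - 1 for each
    #   context_lens    = [4, 7, 5]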
    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices: Dict[SamplingType,
                                         List[Tuple[int, int]]] = {
                                             t: []
                                             for t in SamplingType
                                         }
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert prompt_lens is not None
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need to be sampled;
                    # skip them.
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        (categorized_sample_indices_start_idx,
                         categorized_sampled_token_indices_start_idx))
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += prompt_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

                if sampling_params.seed is not None:
                    generators.append(seq_group_metadata.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata
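
    # Index bookkeeping example for the method above, for a hypothetical
    # all-prompt batch of two greedy sequence groups with prompt lengths 3
    # and 5 and prompt_logprobs disabled:
    #
    #   group 0: selected_token_indices += [2]   (0 + 3 - 1), start_idx -> 3
    #   group 1: selected_token_indices += [7]   (3 + 5 - 1), start_idx -> 8
    #   categorized_sample_indices[SamplingType.GREEDY] = [(0, 0), (1, 1)]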
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
        # NOTE: We assume the sequence groups in the batch are either all
        # prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            prompt_lens = []
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens)

        return (input_tokens, input_positions, input_block_ids,
                sampling_metadata)
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, input_block_ids, sampling_metadata
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_block_ids=input_block_ids,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()
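
# A minimal driver sketch, assuming the four config objects are built by the
# engine (their constructors are engine-specific and elided here):
#
#     runner = NeuronModelRunner(model_config, parallel_config,
#                                scheduler_config, device_config)
#     runner.load_model()
#     # The scheduler supplies `seq_group_metadata_list`; the entries must be
#     # either all prompts or all decodes (see prepare_input_tensors).
#     output = runner.execute_model(seq_group_metadata_list)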