# neuron_model_runner.py

from typing import List, Optional, Tuple

import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (is_pin_memory_available,
                                    make_tensor_with_pad)
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader.neuron import get_neuron_model


class NeuronModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization.
        self.model: nn.Module  # initialized after load_model.

    def load_model(self) -> None:
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)
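
    # NOTE (annotation, not in the original file): get_neuron_model is
    # expected to return an nn.Module wrapping the compiled Neuron model;
    # until load_model() is called, self.model exists only as a type
    # annotation, so any earlier access raises AttributeError.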

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_seq_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_seq_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids, seq_lens
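
    # Illustrative example (added; token and block IDs are made up): for two
    # prompts of length 3 and 5, make_tensor_with_pad right-pads each row to
    # the batch maximum, so _prepare_prompt would yield
    #
    #     input_tokens    = [[11, 12, 13,  0,  0],
    #                        [21, 22, 23, 24, 25]]
    #     input_positions = [[0, 1, 2, 0, 0],
    #                        [0, 1, 2, 3, 4]]
    #     input_block_ids = [b0, b1]   # one block per sequence
    #     seq_lens        = [3, 5]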

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        # NOTE: context_lens is computed but not returned; it is currently
        # unused by the Neuron backend.
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids
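
    # Shape note (added): during decode every sequence contributes exactly
    # one token, so input_tokens and input_positions are (batch_size, 1)
    # tensors; position = seq_len - 1 is the 0-indexed position of the token
    # being fed in, i.e. the most recently generated token.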

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
        # NOTE: We assume that all sequences in the batch are either all
        # prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = []

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is only needed when chunked prefill is supported.
            # Since the Neuron worker doesn't support chunked prefill, just
            # pass seq_lens again in its place.
            seq_lens,
            self.device,
            self.pin_memory)

        return (input_tokens, input_positions, input_block_ids,
                sampling_metadata)
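
    # Rationale (added): SamplingMetadata.prepare takes both seq_lens and
    # query_lens. Without chunked prefill, a prompt is always processed in a
    # single forward pass, so each sequence's query length equals its
    # sequence length and seq_lens can simply be passed twice.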

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, input_block_ids,
         sampling_metadata) = self.prepare_input_tensors(
             seq_group_metadata_list)

        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_block_ids=input_block_ids,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()
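

# Usage sketch (added; not part of the original file). The config objects and
# seq_group_metadata_list come from the engine and scheduler in a real
# deployment; this only illustrates the call order the runner expects.
#
#     runner = NeuronModelRunner(model_config, parallel_config,
#                                scheduler_config, device_config)
#     runner.load_model()
#     sampler_output = runner.execute_model(seq_group_metadata_list)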