neuron.py

  1. """Utilities for selecting and loading neuron models."""
  2. import importlib
  3. import os
  4. from typing import Dict, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. import transformers
  8. from transformers import PretrainedConfig
  9. from aphrodite.common.config import (ModelConfig, ParallelConfig,
  10. SchedulerConfig)
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  13. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  14. from aphrodite.quantization import get_quantization_config

TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}
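# For illustration: both the string and the torch dtype spellings resolve to
# the same Neuron AMP code, e.g. TORCH_DTYPE_TO_NEURON_AMP["bfloat16"] and
# TORCH_DTYPE_TO_NEURON_AMP[torch.bfloat16] are both "bf16".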

# Models supported by Neuron: maps the HF architecture name to a
# (transformers_neuronx module path, Neuron model class name,
# HuggingFace model class name) triple.
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
                           "MistralForSampling", "MistralForCausalLM"),
}
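# A hypothetical sketch of how another architecture would be registered,
# assuming transformers_neuronx shipped a matching *ForSampling class:
#
#   "GPTNeoXForCausalLM": ("transformers_neuronx.gptneox.model",
#                          "GPTNeoXForSampling", "GPTNeoXForCausalLM"),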


class NeuronCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazily initialized by load_weights().
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        # transformers_neuronx models return logits directly, which is why
        # the LogitsProcessor above is constructed with logits_as_input=True.
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, model_name_or_path: str, **kwargs):
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

        split_model_dir = f"{model_name_or_path}-split"
        if _is_pretrained_neuron_checkpoint(model_name_or_path):
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            # No split checkpoint yet: load the HF model on CPU and save it
            # in the split format that transformers_neuronx expects.
            hf_model_cls = getattr(transformers, hf_model_cls_name)
            from transformers_neuronx.module import save_pretrained_split

            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
                                                    low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
                                                       **kwargs)
        self.model.to_neuron()
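

# Checkpoint resolution in NeuronCausalLM.load_weights, spelled out:
#   1. model_name_or_path already holds a pretrained Neuron checkpoint ->
#      it is loaded in place;
#   2. a "<model_name_or_path>-split" directory exists from an earlier run ->
#      it is reused;
#   3. otherwise the HuggingFace checkpoint is split once and cached.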


def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    # Check whether the neuron checkpoint is saved in the old format, where
    # "pytorch_model.bin" is a directory rather than a single file.
    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
        return True
    # Check whether the neuron checkpoint is saved in the new format:
    # config files plus safetensors weight shards.
    pretrained_split_files = ["config.json", "generation_config.json"]
    pretrained_split_format = ".safetensors"
    for file in pretrained_split_files:
        file_path = os.path.join(model_name_or_path, file)
        if not os.path.isfile(file_path):
            return False
    for file in os.listdir(model_name_or_path):
        if file.endswith(pretrained_split_format):
            return True
    return False
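# For illustration, the two on-disk layouts this check accepts (only the
# names tested above matter; other files are incidental):
#
#   old format:                new format:
#     model/                     model/
#       pytorch_model.bin/         config.json
#         ...                      generation_config.json
#                                  model.safetensors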


def _get_model_architecture(config: PretrainedConfig) -> str:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")


def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    env_value = os.getenv(env)
    if env_value is None:
        return default_value
    buckets_remove_empty = filter(
        lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
    buckets_int = map(int, buckets_remove_empty)
    buckets_list = list(buckets_int)
    return buckets_list
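# Example: with NEURON_TOKEN_GEN_BUCKETS="128, 512,2048" in the environment,
# _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [4096]) returns [128, 512, 2048];
# with the variable unset, it returns the default, [4096].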


def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    quant_config = dict(
        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        quantize_method="vector_dynamic")
    neuron_quantization_config_builder = lambda quant: get_quantization_config(
        quant).from_config(quant_config).get_quant_method(None, "")
    # TODO: Add Paged attention config to the default neuron arguments.
    default_neuron_args = dict(
        collectives_layout=LAYOUT_BSH,
        attention_layout=LAYOUT_BSH,
        fuse_qkv=True,
        quant=neuron_quantization_config_builder(model_config.quantization)
        if model_config.quantization else None,
        continuous_batching=continuous_batching_config,
        weight_tiling=bool(model_config.quantization))
    return default_neuron_args
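# Illustrative result, assuming an unquantized model and max_num_seqs=8: the
# returned dict carries BSH-layout attention and collectives, fuse_qkv=True,
# quant=None, weight_tiling=False, and a ContinuousBatchingConfig with
# batch_size_for_shared_caches=8.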


def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    from transformers_neuronx.config import NeuronConfig

    overridden_neuron_config = overridden_neuron_config or {}
    default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
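# Example: an override of {"fuse_qkv": False} replaces the default
# fuse_qkv=True before the merged kwargs are passed to NeuronConfig; keys
# that are not overridden keep their default values.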


def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    # Create a model instance.
    model = NeuronCausalLM(model_config.hf_config)

    default_neuron_config_args = _get_default_neuron_config(
        model_config, parallel_config, scheduler_config)
    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_config.model,
                       tp_degree=parallel_config.tensor_parallel_size,
                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
                       neuron_config=neuron_config,
                       context_length_estimate=context_length_estimates,
                       n_positions=n_positions,
                       batch_size=scheduler_config.max_num_seqs)

    return model.eval()
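

# A minimal usage sketch (hypothetical; the engine normally builds these
# config objects itself, and their constructors take more arguments than
# shown here):
#
#   model_config = ModelConfig(model="meta-llama/Llama-2-7b-hf", ...)
#   parallel_config = ParallelConfig(tensor_parallel_size=2, ...)
#   scheduler_config = SchedulerConfig(max_num_seqs=8, max_model_len=2048,
#                                      ...)
#   model = get_neuron_model(model_config, parallel_config, scheduler_config)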