1
0

neuron_loader.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """Utilities for selecting and loading neuron models."""
  2. import importlib
  3. import os
  4. from typing import Optional, Type
  5. import torch
  6. import torch.nn as nn
  7. import transformers
  8. from transformers import PretrainedConfig
  9. from aphrodite.common.config import ModelConfig, ParallelConfig, SchedulerConfig
  10. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  11. from aphrodite.modeling.layers.sampler import Sampler
  12. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  13. from aphrodite.common.sequence import SamplerOutput
  14. TORCH_DTYPE_TO_NEURON_AMP = {
  15. "auto": "f32",
  16. "half": "f16",
  17. "float16": "f16",
  18. "bfloat16": "bf16",
  19. "float": "f32",
  20. "float32": "f32",
  21. torch.float16: "f16",
  22. torch.bfloat16: "bf16",
  23. torch.float32: "f32",
  24. }
  25. # Models supported by Neuron.
  26. _NEURON_SUPPORTED_MODELS = {
  27. "LlamaForCausalLM": ("transformers_neuronx.llama.model",
  28. "LlamaForSampling", "LlamaForCausalLM"),
  29. "MistralForCausalLM": ("transformers_neuronx.mistral.model",
  30. "MistralForSampling", "MistralForCausalLM")
  31. }
  32. class NeuronCasualLM(nn.Module):
  33. def __init__(
  34. self,
  35. config: PretrainedConfig,
  36. ) -> None:
  37. super().__init__()
  38. self.config = config
  39. self.model = None
  40. self.logits_processor = LogitsProcessor(config.vocab_size,
  41. logits_as_input=True)
  42. self.sampler = Sampler()
  43. def forward(
  44. self,
  45. input_ids: torch.Tensor,
  46. positions: torch.Tensor,
  47. input_block_ids: torch.Tensor,
  48. ) -> torch.Tensor:
  49. logits = self.model(input_ids,
  50. cache_ids=positions,
  51. start_ids=input_block_ids)
  52. return logits
  53. def compute_logits(self, hidden_states: torch.Tensor,
  54. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  55. logits = self.logits_processor(None, hidden_states, sampling_metadata)
  56. return logits
  57. def sample(
  58. self,
  59. logits: torch.Tensor,
  60. sampling_metadata: SamplingMetadata,
  61. ) -> Optional[SamplerOutput]:
  62. next_tokens = self.sampler(logits, sampling_metadata)
  63. return next_tokens
  64. def load_weights(self, model_name_or_path: str, **kwargs):
  65. arch = _get_model_architecture(self.config)
  66. neuronx_module_path, neuronx_model_cls, hf_model_cls = (
  67. _NEURON_SUPPORTED_MODELS[arch])
  68. neuronx_module = importlib.import_module(neuronx_module_path)
  69. neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
  70. split_model_dir = f"{model_name_or_path}-split"
  71. if os.path.isdir(os.path.join(model_name_or_path,
  72. "pytorch_model.bin")):
  73. split_model_dir = model_name_or_path
  74. elif not os.path.exists(f"{model_name_or_path}-split"):
  75. hf_model_cls = getattr(transformers, hf_model_cls)
  76. from transformers_neuronx.module import save_pretrained_split
  77. hf_model = hf_model_cls.from_pretrained(model_name_or_path,
  78. low_cpu_mem_usage=True)
  79. save_pretrained_split(hf_model, f"{model_name_or_path}-split")
  80. self.model = neuronx_model_cls.from_pretrained(split_model_dir,
  81. **kwargs)
  82. self.model.to_neuron()
  83. def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
  84. architectures = getattr(config, "architectures", [])
  85. for arch in architectures:
  86. if arch in _NEURON_SUPPORTED_MODELS:
  87. return arch
  88. raise ValueError(
  89. f"Model architectures {architectures} are not supported on Neuron "
  90. f"for now. Supported architectures: "
  91. f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
  92. def get_neuron_model(model_config: ModelConfig,
  93. parallel_config: ParallelConfig,
  94. scheduler_config: SchedulerConfig) -> nn.Module:
  95. from transformers_neuronx.config import (NeuronConfig,
  96. ContinuousBatchingConfig)
  97. # Create a model instance.
  98. model = NeuronCasualLM(model_config.hf_config)
  99. continuous_batching_config = ContinuousBatchingConfig(
  100. batch_size_for_shared_caches=scheduler_config.max_num_seqs)
  101. neuron_config = NeuronConfig(
  102. continuous_batching=continuous_batching_config)
  103. # Load the weights from the cached or downloaded files.
  104. model.load_weights(
  105. model_config.model,
  106. tp_degree=parallel_config.tensor_parallel_size,
  107. amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
  108. neuron_config=neuron_config,
  109. context_length_estimate=[scheduler_config.max_model_len],
  110. n_positions=[scheduler_config.max_model_len],
  111. batch_size=scheduler_config.max_num_seqs,
  112. )
  113. return model.eval()